Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 16
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp | 109
-rw-r--r--  mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 78
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 199
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 93
-rw-r--r--  mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp | 12
-rw-r--r--  mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp | 60
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp | 82
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp | 14
-rw-r--r--  mlir/lib/Dialect/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/DLTI/Traits.cpp | 2
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 69
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 150
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 98
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp | 13
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp | 57
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 12
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp | 48
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp | 1
-rw-r--r--  mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp | 38
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp | 487
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 120
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 71
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 31
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 448
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp | 7
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 508
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 43
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp | 256
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 275
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 22
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp | 65
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp | 62
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp | 98
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 28
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp (renamed from mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp) | 15
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 132
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 197
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (renamed from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp) | 139
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp | 6
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp | 70
-rw-r--r--  mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 6
-rw-r--r--  mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 24
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 78
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 153
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 20
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 122
-rw-r--r--  mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 52
-rw-r--r--  mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 39
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 3
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp | 9
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp | 86
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 15
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 63
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp | 46
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp | 251
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp | 12
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 74
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Partition.cpp | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp | 7
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 568
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 23
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 238
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 40
-rw-r--r--  mlir/lib/Dialect/Transform/IR/Utils.cpp | 33
-rw-r--r--  mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp | 86
-rw-r--r--  mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 217
-rw-r--r--  mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp | 65
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 45
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 63
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 39
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 51
-rw-r--r--  mlir/lib/Dialect/WasmSSA/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt | 24
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp | 38
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp | 69
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 494
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp | 18
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 5
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 355
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 388
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 39
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 82
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 630
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 98
146 files changed, 7228 insertions, 2184 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 9a0a230..11a40d6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -511,6 +511,18 @@ LogicalResult DPPOp::verify() {
}
//===----------------------------------------------------------------------===//
+// PermlaneSwapOp
+//===----------------------------------------------------------------------===//
+LogicalResult PermlaneSwapOp::verify() {
+ unsigned rowLength = getRowLength();
+
+ if (rowLength != 16 && rowLength != 32)
+ return emitOpError("row_length attribute must either be 16 or 32.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
@@ -518,8 +530,8 @@ LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
- if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
- return emitOpError("destination types must be contiguous");
+ if (!dstType.areTrailingDimsContiguous(1))
+ return emitOpError("destination type inner most dim must be contiguous");
auto elemType = srcType.getElementType();
// Check $src and $dst element types are the same.
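[Editor's note: the relaxed GatherToLDSOp check above only requires the innermost dimension of the destination memref to be contiguous, rather than the whole type. A minimal sketch of that distinction, assuming the MemRefType/StridedLayoutAttr builders shown; the helper is illustrative only and not part of the patch.]
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// A 4x8 f32 memref with strides [16, 1]: the innermost dim is contiguous
// (stride 1), but the type is not fully contiguous because the outer stride
// (16) exceeds the row length (8).
static void contiguityExample(mlir::MLIRContext &ctx) {
  auto f32 = mlir::Float32Type::get(&ctx);
  auto layout =
      mlir::StridedLayoutAttr::get(&ctx, /*offset=*/0, /*strides=*/{16, 1});
  auto type = mlir::MemRefType::get({4, 8}, f32, layout);
  bool innerOk = type.areTrailingDimsContiguous(1); // true: new check passes.
  bool fullOk = type.areTrailingDimsContiguous(2);  // false: old check rejected it.
  (void)innerOk;
  (void)fullOk;
}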
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 729e3da..d35853b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
ResolveStridedMetadata.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
DEPENDS
MLIRAMDGPUTransformsIncGen
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7e..d547510 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final
}
};
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+ Value view, mlir::OperandRange indices,
+ SmallVectorImpl<Value> &resolvedIndices,
+ Value &memrefBase, StringRef role) {
+ Operation *defOp = view.getDefiningOp();
+ if (!defOp) {
+ return failure();
+ }
+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+ .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, subviewOp.getMixedOffsets(),
+ subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+ resolvedIndices);
+ memrefBase = subviewOp.getSource();
+ return success();
+ })
+ .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, indices, resolvedIndices,
+ false))) {
+ return failure();
+ }
+ memrefBase = expandShapeOp.getViewSource();
+ return success();
+ })
+ .Case<memref::CollapseShapeOp>(
+ [&](memref::CollapseShapeOp collapseShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+ loc, rewriter, collapseShapeOp, indices,
+ resolvedIndices))) {
+ return failure();
+ }
+ memrefBase = collapseShapeOp.getViewSource();
+ return success();
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(
+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+ "CollapseShapeOp")
+ .str());
+ });
+}
+
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value memrefSource;
- SmallVector<Value> sourceIndices;
- auto foldResult =
- llvm::TypeSwitch<Operation *, LogicalResult>(
- op.getSrc().getDefiningOp())
- .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
- // If the source is a SubViewOp, we can directly rewrite the
- // GatherToLDSOp.
- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, loc, subviewOp.getMixedOffsets(),
- subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
- op.getSrcIndices(), sourceIndices);
- memrefSource = subviewOp.getSource();
- return success();
- })
- .Case<memref::ExpandShapeOp>(
- [&](memref::ExpandShapeOp expandShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesExpandShape(
- loc, rewriter, expandShapeOp, op.getSrcIndices(),
- sourceIndices, false))) {
- return failure();
- }
- memrefSource = expandShapeOp.getViewSource();
- return success();
- })
- .Case<memref::CollapseShapeOp>(
- [&](memref::CollapseShapeOp collapseShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
- loc, rewriter, collapseShapeOp, op.getSrcIndices(),
- sourceIndices))) {
- return failure();
- }
- memrefSource = collapseShapeOp.getViewSource();
- return success();
- })
- .Default([&](Operation *op) {
- // If the source is not a SubViewOp, ExpandShapeOp, or
- // CollapseShapeOp, we cannot fold the GatherToLDSOp.
- return rewriter.notifyMatchFailure(
- op,
- "source producer is not one of SubViewOp, ExpandShapeOp, or "
- "CollapseShapeOp");
- });
+ SmallVector<Value> sourceIndices, destIndices;
+ Value memrefSource, memrefDest;
+
+ auto foldSrcResult =
+ foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+ sourceIndices, memrefSource, "source");
+
+ if (failed(foldSrcResult)) {
+ memrefSource = op.getSrc();
+ sourceIndices = op.getSrcIndices();
+ }
+
+ auto foldDstResult =
+ foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+ destIndices, memrefDest, "destination");
- if (failed(foldResult)) {
- return failure();
+ if (failed(foldDstResult)) {
+ memrefDest = op.getDst();
+ destIndices = op.getDstIndices();
}
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
- op.getDst(), op.getDstIndices(),
+ memrefDest, destIndices,
op.getTransferType());
return success();
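[Editor's note: the extracted foldMemrefViewOp helper above dispatches on the defining op with llvm::TypeSwitch. As a reference for that idiom, here is a minimal, self-contained sketch; the Shape/Circle/Square hierarchy is hypothetical and only demonstrates how .Case<> and .Default compose.]
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <cstdio>

struct Shape {
  enum Kind { CircleKind, SquareKind } kind;
  explicit Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(CircleKind) {}
  static bool classof(const Shape *s) { return s->kind == CircleKind; }
};
struct Square : Shape {
  Square() : Shape(SquareKind) {}
  static bool classof(const Shape *s) { return s->kind == SquareKind; }
};

static const char *describe(Shape *s) {
  // Each Case fires on the first matching dynamic type; Default handles the rest.
  return llvm::TypeSwitch<Shape *, const char *>(s)
      .Case<Circle>([](Circle *) { return "circle"; })
      .Case<Square>([](Square *) { return "square"; })
      .Default([](Shape *) { return "unknown"; });
}

int main() {
  Circle c;
  std::printf("%s\n", describe(&c)); // prints "circle"
}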
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110c..68990ef 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
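[Editor's note: for context on the change above, ::getChecked runs the type's verifier but reports failures through the supplied emitError callback and returns a null type, so the parser fails gracefully instead of asserting. A hedged sketch of the general pattern, with MyShapedType standing in as a hypothetical ODS-generated type.]
// MyShapedType is hypothetical and exists only to show the getChecked idiom.
static mlir::Type parseMyShapedType(mlir::AsmParser &parser,
                                    llvm::ArrayRef<int64_t> shape,
                                    mlir::Type elementType) {
  // On verification failure this emits a diagnostic at the parse location and
  // returns a null Type, which the caller propagates as a parse error.
  return MyShapedType::getChecked(
      [&] { return parser.emitError(parser.getNameLoc()); }, shape, elementType);
}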
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 86edc2b..b405ec2 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
int64_t lb = forOp.getConstantLowerBound();
dividend[pos] = 1;
dividend.back() -= lb;
- addLocalFloorDiv(dividend, step);
+ unsigned qPos = addLocalFloorDiv(dividend, step);
// Second constraint: (iv - lb) - step * q = 0.
SmallVector<int64_t, 8> eq(getNumCols(), 0);
eq[pos] = 1;
eq.back() -= lb;
// For the local var just added above.
- eq[getNumCols() - 2] = -step;
+ eq[qPos] = -step;
addEquality(eq);
}
}
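[Editor's note: the domain built above encodes a stepped loop bound using a local quotient variable; the fix records the column position returned by addLocalFloorDiv instead of assuming the quotient sits in the second-to-last column. Informally, the two constraints are:]
  q = floor((iv - lb) / step)        // local var added by addLocalFloorDiv, at column qPos
  (iv - lb) - step * q = 0           // the equality added immediately afterwards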
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 2f85e0b..166d39e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -21,6 +21,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <numeric>
#include <optional>
@@ -548,19 +549,19 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
// Check whether there is any negative direction vector in the
// dependence components found above, which means that dependence is
// violated by the default hyper-rect tiling method.
- LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
- "for dependence at depth: "
- << Twine(d) << " between:\n";);
- LLVM_DEBUG(srcAccess.opInst->dump());
- LLVM_DEBUG(dstAccess.opInst->dump());
+ LDBG() << "Checking whether tiling legality violated "
+ << "for dependence at depth: " << Twine(d) << " between:"
+ << OpWithFlags(srcAccess.opInst, OpPrintingFlags().skipRegions())
+ << "\nand:\n"
+ << OpWithFlags(dstAccess.opInst,
+ OpPrintingFlags().skipRegions());
for (const DependenceComponent &depComp : depComps) {
if (depComp.lb.has_value() && depComp.ub.has_value() &&
*depComp.lb < *depComp.ub && *depComp.ub < 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Dependence component lb = " << Twine(*depComp.lb)
- << " ub = " << Twine(*depComp.ub)
- << " is negative at depth: " << Twine(d)
- << " and thus violates the legality rule.\n");
+ LDBG() << "Dependence component lb = " << Twine(*depComp.lb)
+ << " ub = " << Twine(*depComp.ub)
+ << " is negative at depth: " << Twine(d)
+ << " and thus violates the legality rule.";
return false;
}
}
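[Editor's note: this file, and most of the Affine changes below, migrate LLVM_DEBUG(llvm::dbgs() << ...) to the LDBG() stream from llvm/Support/DebugLog.h. A minimal sketch of the before/after pattern, assuming only what the replacements above imply: LDBG() is gated on the same DEBUG_TYPE and appends its own newline.]
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "my-pass" // hypothetical pass name, for illustration only

static void reportDepth(int depth) {
  // Old style: explicit dbgs() stream and explicit trailing newline.
  LLVM_DEBUG(llvm::dbgs() << "checking at depth: " << depth << "\n");
  // New style: LDBG() owns the stream and the newline; still only active
  // under -debug-only=my-pass in asserts builds.
  LDBG() << "checking at depth: " << depth;
}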
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index a89c1ae..99ea20b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -241,7 +242,7 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
}
bool MemRefDependenceGraph::init() {
- LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+ LDBG() << "--- Initializing MDG ---";
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
@@ -288,8 +289,7 @@ bool MemRefDependenceGraph::init() {
// Return false if non-handled/unknown region-holding ops are found. We
// won't know what such ops do or what its regions mean; for e.g., it may
// not be an imperative op.
- LLVM_DEBUG(llvm::dbgs()
- << "MDG init failed; unknown region-holding op found!\n");
+ LDBG() << "MDG init failed; unknown region-holding op found!";
return false;
}
// We aren't creating nodes for memory-effect free ops either with no
@@ -297,7 +297,7 @@ bool MemRefDependenceGraph::init() {
// interface.
}
- LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");
+ LDBG() << "Created " << nodes.size() << " nodes";
// Add dependence edges between nodes which produce SSA values and their
// users. Load ops can be considered as the ones producing SSA values.
@@ -556,9 +556,8 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
gatherDefiningNodes(dstId, definingNodes);
if (llvm::any_of(definingNodes,
[&](unsigned id) { return hasDependencePath(srcId, id); })) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: a defining op with a user in the dst "
- "loop has dependence from the src loop\n");
+ LDBG() << "Can't fuse: a defining op with a user in the dst "
+ << "loop has dependence from the src loop";
return nullptr;
}
@@ -957,20 +956,20 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
FlatAffineValueConstraints srcConstraints;
// TODO: Store the source's domain to avoid computation at each depth.
if (failed(getSourceAsConstraints(srcConstraints))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
+ LDBG() << "Unable to compute source's domain";
return std::nullopt;
}
// As the set difference utility currently cannot handle symbols in its
// operands, validity of the slice cannot be determined.
if (srcConstraints.getNumSymbolVars() > 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
+ LDBG() << "Cannot handle symbols in source domain";
return std::nullopt;
}
// TODO: Handle local vars in the source domains while using the 'projectOut'
// utility below. Currently, aligning is not done assuming that there will be
// no local vars in the source domain.
if (srcConstraints.getNumLocalVars() != 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
+ LDBG() << "Cannot handle locals in source domain";
return std::nullopt;
}
@@ -978,7 +977,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
// fusion succeeds.
FlatAffineValueConstraints sliceConstraints;
if (failed(getAsConstraints(&sliceConstraints))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
+ LDBG() << "Unable to compute slice's domain";
return std::nullopt;
}
@@ -987,11 +986,11 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
sliceConstraints.projectOut(ivs.size(),
sliceConstraints.getNumVars() - ivs.size());
- LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
- LLVM_DEBUG(srcConstraints.dump());
- LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
- "(expressed in terms of its source's IVs):\n");
- LLVM_DEBUG(sliceConstraints.dump());
+ LDBG() << "Domain of the source of the slice:\n"
+ << "Source constraints:" << srcConstraints
+ << "\nDomain of the slice if this fusion succeeds "
+ << "(expressed in terms of its source's IVs):\n"
+ << "Slice constraints:" << sliceConstraints;
// TODO: Store 'srcSet' to avoid recalculating for each depth.
PresburgerSet srcSet(srcConstraints);
@@ -999,7 +998,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
PresburgerSet diffSet = sliceSet.subtract(srcSet);
if (!diffSet.isIntegerEmpty()) {
- LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
+ LDBG() << "Incorrect slice";
return false;
}
return true;
@@ -1172,8 +1171,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
unsigned rank = access.getRank();
- LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
- << "\ndepth: " << loopDepth << "\n";);
+ LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;
// 0-d memrefs.
if (rank == 0) {
@@ -1236,7 +1234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
if (auto constVal = getConstantIntValue(symbol))
cst.addBound(BoundType::EQ, symbol, constVal.value());
} else {
- LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
+ LDBG() << "unknown affine dimensional value";
return failure();
}
}
@@ -1260,7 +1258,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
// Add access function equalities to connect loop IVs to data dimensions.
if (failed(cst.composeMap(&accessValueMap))) {
op->emitError("getMemRefRegion: compose affine map failed");
- LLVM_DEBUG(accessValueMap.getAffineMap().dump());
+ LDBG() << "Access map: " << accessValueMap.getAffineMap();
return failure();
}
@@ -1317,8 +1315,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
}
cst.removeTrivialRedundancy();
- LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
- LLVM_DEBUG(cst.dump());
+ LDBG() << "Memory region: " << cst;
return success();
}
@@ -1346,14 +1343,14 @@ std::optional<int64_t> MemRefRegion::getRegionSize() {
auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ LDBG() << "Non-identity layout map not yet supported";
return false;
}
// Compute the extents of the buffer.
std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
if (!numElements) {
- LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+ LDBG() << "Dynamic shapes not yet supported";
return std::nullopt;
}
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
@@ -1397,8 +1394,7 @@ LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
/*addMemRefDimBounds=*/false)))
return success();
- LLVM_DEBUG(llvm::dbgs() << "Memory region");
- LLVM_DEBUG(region.getConstraints()->dump());
+ LDBG() << "Memory region: " << region.getConstraints();
bool outOfBounds = false;
unsigned rank = loadOrStoreOp.getMemRefType().getRank();
@@ -1558,7 +1554,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Check if 'loopDepth' exceeds nesting depth of src/dst ops.
if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
(isBackwardSlice && loopDepth > getNestingDepth(b))) {
- LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
+ LDBG() << "Invalid loop depth";
return SliceComputationResult::GenericFailure;
}
@@ -1571,7 +1567,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
&dependenceConstraints, /*dependenceComponents=*/nullptr,
/*allowRAR=*/readReadAccesses);
if (result.value == DependenceResult::Failure) {
- LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
+ LDBG() << "Dependence check failed";
return SliceComputationResult::GenericFailure;
}
if (result.value == DependenceResult::NoDependence)
@@ -1586,8 +1582,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
// Initialize 'sliceUnionCst' with the bounds computed in previous step.
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n");
+ LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
@@ -1597,8 +1592,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
FlatAffineValueConstraints tmpSliceCst;
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n");
+ LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
@@ -1630,8 +1624,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumLocalVars() > 0 ||
tmpSliceCst.getNumLocalVars() > 0 ||
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute union bounding box of slice bounds\n");
+ LDBG() << "Unable to compute union bounding box of slice bounds";
return SliceComputationResult::GenericFailure;
}
}
@@ -1639,7 +1632,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Empty union.
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
- LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n");
+ LDBG() << "empty slice union - unexpected";
return SliceComputationResult::GenericFailure;
}
@@ -1652,7 +1645,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
unsigned innermostCommonLoopDepth =
getInnermostCommonLoopDepth(ops, &surroundingLoops);
if (loopDepth > innermostCommonLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
+ LDBG() << "Exceeds max loop depth";
return SliceComputationResult::GenericFailure;
}
@@ -1696,7 +1689,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// that the slice is valid, otherwise return appropriate failure status.
std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
if (!isSliceValid) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
+ LDBG() << "Cannot determine if the slice is valid";
return SliceComputationResult::GenericFailure;
}
if (!*isSliceValid)
@@ -2050,7 +2043,8 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (failed(
region->compute(opInst,
/*loopDepth=*/getNestingDepth(&*block.begin())))) {
- LLVM_DEBUG(opInst->emitError("error obtaining memory region"));
+ LDBG() << "Error obtaining memory region";
+ opInst->emitError("error obtaining memory region");
return failure();
}
@@ -2058,9 +2052,11 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (inserted) {
it->second = std::move(region);
} else if (failed(it->second->unionBoundingBox(*region))) {
- LLVM_DEBUG(opInst->emitWarning(
+ LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
+ "memory region";
+ opInst->emitWarning(
"getMemoryFootprintBytes: unable to perform a union on a memory "
- "region"));
+ "region");
return failure();
}
return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 22608a1..7e5ce26 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -427,6 +427,21 @@ bool mlir::affine::isValidSymbol(Value value) {
return false;
}
+/// A utility function to check if a value is defined at the top level of
+/// `region` or is an argument of `region` or is defined above the region.
+static bool isTopLevelValueOrAbove(Value value, Region *region) {
+ Region *parentRegion = value.getParentRegion();
+ do {
+ if (parentRegion == region)
+ return true;
+ Operation *regionOp = region->getParentOp();
+ if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ break;
+ region = region->getParentOp()->getParentRegion();
+ } while (region);
+ return false;
+}
+
/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
@@ -445,19 +460,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
return false;
// A top-level value is a valid symbol.
- if (region && ::isTopLevelValue(value, region))
+ if (region && isTopLevelValueOrAbove(value, region))
return true;
auto *defOp = value.getDefiningOp();
- if (!defOp) {
- // A block argument that is not a top-level value is a valid symbol if it
- // dominates region's parent op.
- Operation *regionOp = region ? region->getParentOp() : nullptr;
- if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
- if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
- return isValidSymbol(value, parentOpRegion);
+ if (!defOp)
return false;
- }
// Constant operation is ok.
Attribute operandCst;
@@ -475,12 +483,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
return isDimOpValidSymbol(dimOp, region);
- // Check for values dominating `region`'s parent op.
- Operation *regionOp = region ? region->getParentOp() : nullptr;
- if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
- if (auto *parentRegion = region->getParentOp()->getParentRegion())
- return isValidSymbol(value, parentRegion);
-
return false;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6c9adff..ff0157e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
@@ -95,8 +96,8 @@ static bool canRemoveSrcNodeAfterFusion(
// Otherwise, the src loop can't be removed.
if (fusedLoopInsPoint != depNodeOp &&
!fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
- "dominate dependence\n");
+ LDBG() << "Src loop can't be removed: dst loop doesn't "
+ << "dominate dependence";
return false;
}
@@ -109,14 +110,13 @@ static bool canRemoveSrcNodeAfterFusion(
if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
std::optional<bool> isMaximal = fusionSlice.isMaximal();
if (!isMaximal) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
- "if fusion is maximal\n");
+ LDBG() << "Src loop can't be removed: can't determine "
+ << "if fusion is maximal";
return false;
}
if (!*isMaximal) {
- LLVM_DEBUG(llvm::dbgs()
- << "Src loop can't be removed: fusion is not maximal\n");
+ LDBG() << "Src loop can't be removed: fusion is not maximal";
return false;
}
}
@@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) {
// Check if this is defined to be an alias of another memref.
if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
- if (isEscapingMemref(viewOp.getViewSource(), block))
+ if (memref == viewOp.getViewDest() &&
+ isEscapingMemref(viewOp.getViewSource(), block))
return true;
// Any op besides allocating ops wouldn't guarantee alias freedom
@@ -279,19 +280,19 @@ static std::optional<double> getAdditionalComputeFraction(
AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
int64_t &fusedLoopNestComputeCost) {
- LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+ LDBG() << "Determining additional compute fraction...";
// Compute cost of sliced and unsliced src loop nest.
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+ LDBG() << "Failed to get source loop nest stats.";
return std::nullopt;
}
// Compute cost of dst loop nest.
LoopNestStats dstLoopNestStats;
if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+ LDBG() << "Failed to get destination loop nest stats.";
return std::nullopt;
}
@@ -304,14 +305,14 @@ static std::optional<double> getAdditionalComputeFraction(
const ComputationSliceState &slice = depthSliceUnions[depth - 1];
// Skip slice union if it wasn't computed for this depth.
if (slice.isEmpty()) {
- LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+ LDBG() << "Slice wasn't computed.";
return std::nullopt;
}
if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
dstLoopNestStats, slice,
&fusedLoopNestComputeCost)) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+ LDBG() << "Unable to compute fusion compute cost";
return std::nullopt;
}
@@ -348,9 +349,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
MemRefAccess bM(cast<AffineWriteOpInterface>(b));
return aM == bM;
})) {
- LLVM_DEBUG(llvm::dbgs()
- << "Private memref creation unsupported for multiple producer "
- "stores with different access functions.\n");
+ LDBG() << "Private memref creation unsupported for multiple producer "
+ << "stores with different access functions.";
return nullptr;
}
@@ -455,8 +455,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
- LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
- << '\n');
+ LDBG() << "Created private memref of type: " << newMemRefType;
return newMemRef;
}
@@ -505,15 +504,12 @@ static bool isFusionProfitable(AffineForOp srcForOp,
unsigned maxLegalFusionDepth,
unsigned *dstLoopDepth,
double computeToleranceThreshold) {
- LLVM_DEBUG({
- llvm::dbgs()
- << "Checking whether fusion is profitable between source nest:\n";
- llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
- llvm::dbgs() << dstForOp << "\n";
- });
+ LDBG() << "Checking whether fusion is profitable between source nest:";
+ LDBG() << ' ' << srcForOp << " and destination nest:";
+ LDBG() << dstForOp;
if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
+ LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
return false;
}
@@ -537,8 +533,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// TODO: Suppport multiple producer stores in profitability
// analysis.
if (producerStores.size() > 1) {
- LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
- "supported for multiple producer store case.\n");
+ LDBG() << "Limited profitability analysis. Not "
+ << "supported for multiple producer store case.";
int64_t sliceCost;
int64_t fusedLoopNestComputeCost;
// We will still fuse if fusion obeys the specified compute
@@ -547,12 +543,11 @@ static bool isFusionProfitable(AffineForOp srcForOp,
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
fusedLoopNestComputeCost);
if (!fraction || fraction > computeToleranceThreshold) {
- LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
- "compute tolerance. Not fusing.\n");
+ LDBG() << "Additional computation exceeds "
+ << "compute tolerance. Not fusing.";
return false;
}
- LLVM_DEBUG(llvm::dbgs()
- << "Considering fusion profitable at max legal depth.\n");
+ LDBG() << "Considering fusion profitable at max legal depth.";
return true;
}
@@ -574,8 +569,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute MemRefRegion for source operation\n");
+ LDBG() << "Unable to compute MemRefRegion for source operation";
return false;
}
@@ -609,8 +603,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!mayAdditionalComputeFraction) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't determine additional compute fraction.\n");
+ LDBG() << "Can't determine additional compute fraction.";
continue;
}
double additionalComputeFraction = *mayAdditionalComputeFraction;
@@ -620,9 +613,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to compute slice write region at loopDepth: " << i
- << "\n");
+ LDBG() << "Failed to compute slice write region at loopDepth: " << i;
continue;
}
@@ -630,9 +621,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
sliceWriteRegion.getRegionSize();
if (!maybeSliceWriteRegionSizeBytes.has_value() ||
*maybeSliceWriteRegionSizeBytes == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to get slice write region size at loopDepth: " << i
- << "\n");
+ LDBG() << "Failed to get slice write region size at loopDepth: " << i;
continue;
}
int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
@@ -649,9 +638,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
<< " storage reduction factor: " << storageReduction << "x\n"
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
- << " slice write region size: " << sliceWriteRegionSizeBytes
- << "\n";
- llvm::dbgs() << msg.str();
+ << " slice write region size: " << sliceWriteRegionSizeBytes;
+ LDBG() << msg.str();
});
// TODO: This is a placeholder cost model.
@@ -670,28 +658,24 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// A simple cost model: fuse if it reduces the memory footprint.
if (!bestDstLoopDepth) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "All fusion choices involve more than the threshold amount of "
- "redundant computation; NOT fusing.\n");
+ LDBG() << "All fusion choices involve more than the threshold amount of "
+ << "redundant computation; NOT fusing.";
return false;
}
if (!bestDstLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+ LDBG() << "no fusion depth could be evaluated.";
return false;
}
// Set dstLoopDepth based on best values from search.
*dstLoopDepth = *bestDstLoopDepth;
- LLVM_DEBUG(
- llvm::dbgs() << " LoopFusion fusion stats:"
- << "\n best loop depth: " << bestDstLoopDepth
- << "\n src loop nest compute cost: " << srcLoopNestCost
- << "\n dst loop nest compute cost: " << dstLoopNestCost
- << "\n fused loop nest compute cost: "
- << minFusedLoopNestComputeCost << "\n");
+ LDBG() << " LoopFusion fusion stats:";
+ LDBG() << " best loop depth: " << bestDstLoopDepth;
+ LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
+ LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
+ LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
auto dstMemSize = getMemoryFootprintBytes(dstForOp);
auto srcMemSize = getMemoryFootprintBytes(srcForOp);
@@ -699,8 +683,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
std::optional<double> storageReduction;
if (!dstMemSize || !srcMemSize) {
- LLVM_DEBUG(llvm::dbgs()
- << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
return false;
}
@@ -710,13 +693,13 @@ static bool isFusionProfitable(AffineForOp srcForOp,
assert(sliceMemEstimate && "expected value");
auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
- LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
- << " dst mem: " << dstMemSizeVal << "\n"
- << " fused mem: " << fusedMem << "\n"
- << " slice mem: " << sliceMemEstimate << "\n");
+ LDBG() << " src mem: " << srcMemSizeVal;
+ LDBG() << " dst mem: " << dstMemSizeVal;
+ LDBG() << " fused mem: " << fusedMem;
+ LDBG() << " slice mem: " << sliceMemEstimate;
if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ LDBG() << "Fusion is not profitable; NOT fusing.";
return false;
}
storageReduction =
@@ -734,8 +717,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
<< std::setprecision(2) << additionalComputeFraction
<< "% redundant computation and a ";
msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
- msg << "% storage reduction.\n";
- llvm::dbgs() << msg.str();
+ msg << "% storage reduction.";
+ LDBG() << msg.str();
});
return true;
@@ -895,7 +878,7 @@ public:
/// No fusion is performed when producers with a user count greater than
/// `maxSrcUserCount` for any of the memrefs involved.
void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ LDBG() << "Evaluating dst loop " << dstId;
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
return;
@@ -909,7 +892,7 @@ public:
if (dstNode->op->getNumResults() > 0)
return;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ LDBG() << "Evaluating dst loop " << dstId;
// Sink sequential loops in 'dstNode' (and thus raise parallel loops)
// while preserving relative order. This can increase the maximum loop
@@ -936,18 +919,14 @@ public:
auto *srcNode = mdg->getNode(srcId);
auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
- LLVM_DEBUG(llvm::dbgs()
- << "Trying to fuse producer loop nest " << srcId
- << " with consumer loop nest " << dstId << "\n");
- LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
- << computeToleranceThreshold << '\n');
- LLVM_DEBUG(llvm::dbgs()
- << "Producer loop nest:\n"
- << *srcNode->op << "\n and consumer loop nest:\n"
- << *dstNode->op << '\n');
+ LDBG() << "Trying to fuse producer loop nest " << srcId
+ << " with consumer loop nest " << dstId;
+ LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
+ LDBG() << "Producer loop nest:";
+ LDBG() << *srcNode->op << " and consumer loop nest:";
+ LDBG() << *dstNode->op;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
- << " for dst loop " << dstId << "\n");
+ LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
// Skip if 'srcNode' is a loop nest returning values.
// TODO: support loop nests that return values.
@@ -1018,19 +997,16 @@ public:
&depthSliceUnions[i - 1], strategy);
if (result.value == FusionResult::Success) {
maxLegalFusionDepth = i;
- LLVM_DEBUG(llvm::dbgs()
- << "Found valid slice for depth: " << i << '\n');
+ LDBG() << "Found valid slice for depth: " << i;
}
}
if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: fusion is not legal at any depth\n");
+ LDBG() << "Can't fuse: fusion is not legal at any depth";
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
- << maxLegalFusionDepth << '\n');
+ LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
double computeToleranceThresholdToUse = computeToleranceThreshold;
@@ -1040,7 +1016,7 @@ public:
// producer-consumer memref access for example). Check this and allow
// fusion accordingly.
if (hasCyclicDependence(srcAffineForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+ LDBG() << "Source nest has a cyclic dependence.";
// Maximal fusion does not check for compute tolerance threshold; so
// perform the maximal fusion only when the redundanation computation
// is zero.
@@ -1053,18 +1029,15 @@ public:
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!fraction || fraction > 0) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't perform maximal fusion with a cyclic dependence "
- "and non-zero additional compute.\n");
+ LDBG() << "Can't perform maximal fusion with a cyclic dependence "
+ << "and non-zero additional compute.";
return;
}
} else {
// Set redundant computation tolerance to zero regardless of what
// the user specified. Without this, fusion would be invalid.
- LLVM_DEBUG(llvm::dbgs()
- << "Setting compute tolerance to zero since "
- "source has a cylic dependence.\n");
+ LDBG() << "Setting compute tolerance to zero since "
+ << "source has a cylic dependence.";
computeToleranceThresholdToUse = 0;
}
}
@@ -1107,8 +1080,7 @@ public:
if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
removeSrcNode)) {
// Create a private version of this memref.
- LLVM_DEBUG(llvm::dbgs()
- << "Creating private memref for " << memref << '\n');
+ LDBG() << "Creating private memref for " << memref;
// Create a private version of this memref.
privateMemrefs.insert(memref);
}
@@ -1118,10 +1090,9 @@ public:
fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
dstNodeChanged = true;
- LLVM_DEBUG(llvm::dbgs()
- << "Fused src loop " << srcId << " into dst loop " << dstId
- << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
+ LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
+ << " at depth " << bestDstLoopDepth << ":";
+ LDBG() << dstAffineForOp;
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
if (fusedLoopInsPoint != dstAffineForOp)
@@ -1179,8 +1150,7 @@ public:
dstLoopCollector.memrefFrees);
if (removeSrcNode) {
- LLVM_DEBUG(llvm::dbgs()
- << "Removing src loop " << srcId << " after fusion\n");
+ LDBG() << "Removing src loop " << srcId << " after fusion";
// srcNode is no longer valid after it is removed from mdg.
srcAffineForOp.erase();
mdg->removeNode(srcId);
@@ -1195,7 +1165,7 @@ public:
/// user count greater than `maxSrcUserCount` for any of the memrefs involved
/// are encountered.
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
+ LDBG() << "--- Producer/Consumer Fusion ---";
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
@@ -1207,7 +1177,7 @@ public:
// Visits each node in the graph, and for each node, attempts to fuse it with
// its sibling nodes (nodes which share a parent, but no dependence edges).
void fuseSiblingNodes() {
- LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
+ LDBG() << "--- Sibling Fusion ---";
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
@@ -1289,8 +1259,7 @@ public:
maxLegalFusionDepth = i;
}
- LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
- << maxLegalFusionDepth << '\n');
+ LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
// Skip if fusion is not feasible at any loop depths.
if (maxLegalFusionDepth == 0)
@@ -1304,7 +1273,7 @@ public:
// producer-consumer memref access for example). Check this and allow
// fusion accordingly.
if (hasCyclicDependence(sibAffineForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+ LDBG() << "Source nest has a cyclic dependence.";
// Maximal fusion does not check for compute tolerance threshold; so
// perform the maximal fusion only when the redundanation computation is
// zero.
@@ -1316,17 +1285,15 @@ public:
sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!fraction || fraction > 0) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't perform maximal fusion with a cyclic dependence "
- "and non-zero additional compute.\n");
+ LDBG() << "Can't perform maximal fusion with a cyclic dependence "
+ << "and non-zero additional compute.";
return;
}
} else {
// Set redundant computation tolerance to zero regardless of what the
// user specified. Without this, fusion would be invalid.
- LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since "
- "source has a cyclic dependence.\n");
+ LDBG() << "Setting compute tolerance to zero since "
+ << "source has a cyclic dependence.";
computeToleranceThresholdToUse = 0.0;
}
}
@@ -1356,8 +1323,7 @@ public:
// slice is used in the destination.
auto isMaximal = bestSlice.isMaximal();
if (!isMaximal.value_or(false)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Slice isn't maximal; not performing sibling fusion.\n");
+ LDBG() << "Slice isn't maximal; not performing sibling fusion.";
continue;
}
@@ -1374,10 +1340,9 @@ public:
if (insertPointInst != dstForInst)
dstForInst->moveBefore(insertPointInst);
- LLVM_DEBUG(llvm::dbgs()
- << "Fused sibling nest " << sibId << " into destination nest "
- << dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
+ LDBG() << "Fused sibling nest " << sibId << " into destination nest "
+ << dstNode->id << " at depth " << bestDstLoopDepth << ":";
+ LDBG() << dstAffineForOp;
// Update data dependence graph state post fusion.
updateStateAfterSiblingFusion(sibNode, dstNode);
@@ -1555,7 +1520,7 @@ public:
void LoopFusion::runOnBlock(Block *block) {
MemRefDependenceGraph g(*block);
if (!g.init()) {
- LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
+ LDBG() << "MDG init failed";
return;
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 41cd739..c6abb0d 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -251,20 +252,20 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
FusionStrategy fusionStrategy) {
// Return 'failure' if 'dstLoopDepth == 0'.
if (dstLoopDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
+ LDBG() << "Cannot fuse loop nests at depth 0";
return FusionResult::FailPrecondition;
}
// Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
auto *block = srcForOp->getBlock();
if (block != dstForOp->getBlock()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
+ LDBG() << "Cannot fuse loop nests in different blocks";
return FusionResult::FailPrecondition;
}
// Return 'failure' if no valid insertion point for fused loop nest in 'block'
// exists which would preserve dependences.
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
+ LDBG() << "Fusion would violate dependences in block";
return FusionResult::FailBlockDependence;
}
@@ -277,14 +278,14 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
// Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
SmallVector<Operation *, 4> opsA;
if (!gatherLoadsAndStores(forOpA, opsA)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
+ LDBG() << "Fusing loops with affine.if unsupported";
return FusionResult::FailPrecondition;
}
// Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
SmallVector<Operation *, 4> opsB;
if (!gatherLoadsAndStores(forOpB, opsB)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
+ LDBG() << "Fusing loops with affine.if unsupported";
return FusionResult::FailPrecondition;
}
@@ -296,7 +297,7 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
// TODO: 'getMaxLoopDepth' does not support forward slice fusion.
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
+ LDBG() << "Fusion would violate loop dependences";
return FusionResult::FailFusionDependence;
}
}
@@ -339,12 +340,12 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
isSrcForOpBeforeDstForOp, srcSlice);
if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
- LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+ LDBG() << "computeSliceUnion failed";
return FusionResult::FailPrecondition;
}
if (sliceComputationResult.value ==
SliceComputationResult::IncorrectSliceFailure) {
- LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
+ LDBG() << "Incorrect slice computation";
return FusionResult::FailIncorrectSlice;
}
@@ -477,7 +478,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
auto *parentForOp = forOp->getParentOp();
if (forOp != forOpRoot) {
if (!isa<AffineForOp>(parentForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
+ LDBG() << "Expected parent AffineForOp";
return WalkResult::interrupt();
}
// Add mapping to 'forOp' from its parent AffineForOp.
@@ -498,7 +499,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount) {
// Currently only constant trip count loop nests are supported.
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
+ LDBG() << "Non-constant trip count unsupported";
return WalkResult::interrupt();
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 2de057d..cd216ef 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -21,9 +21,11 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -365,12 +367,11 @@ checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) {
if (input.size() <= 1)
return success();
if (failed(getIndexSet(ops, &cst))) {
- LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n");
+ LDBG() << "Index set computation failed!";
return failure();
}
if (!cst.isHyperRectangular(0, input.size())) {
- LLVM_DEBUG(llvm::dbgs()
- << "Non-hyperrectangular nests not supported for tiling!\n");
+ LDBG() << "Non-hyperrectangular nests not supported for tiling!";
return failure();
}
return success();
@@ -385,14 +386,13 @@ static LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input,
if (llvm::any_of(input,
[](AffineForOp op) { return op.getNumResults() > 0; })) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot tile nest where a loop has yield values\n");
+ LDBG() << "Cannot tile nest where a loop has yield values";
return failure();
}
// Check if the supplied `for` ops are all successively nested.
if (!isPerfectlyNested(input)) {
- LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested");
+ LDBG() << "input loops not perfectly nested";
return failure();
}
@@ -1098,7 +1098,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// If the trip count is lower than the unroll jam factor, no unroll jam.
if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollJamFactor) {
- LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
+ LDBG() << "[failed] trip count < unroll-jam factor";
return failure();
}
@@ -1339,6 +1339,15 @@ bool mlir::affine::isValidLoopInterchangePermutation(
unsigned maxLoopDepth = loops.size();
if (maxLoopDepth == 1)
return true;
+
+ // We cannot guarantee the validity of the interchange if the loops have
+ // iter_args, since the dependence analysis does not take them into account.
+ // Conservatively return false in such cases.
+ if (llvm::any_of(loops, [](AffineForOp loop) {
+ return loop.getNumIterOperands() > 0;
+ }))
+ return false;
+
// Gather dependence components for dependences between all ops in loop nest
// rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
@@ -1766,9 +1775,7 @@ findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
// We can't hoist past the definition of the memref being copied.
Value memref = region.memref;
if (!memref.getParentRegion()->isAncestor(enclosingOp->getParentRegion())) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "memref definition will end up not dominating hoist location\n");
+ LDBG() << "memref definition will end up not dominating hoist location";
break;
}
@@ -1977,7 +1984,7 @@ static LogicalResult generateCopy(
auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ LDBG() << "Non-identity layout map not yet supported";
return failure();
}
@@ -1989,7 +1996,7 @@ static LogicalResult generateCopy(
unsigned rank = memRefType.getRank();
if (rank == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Non-zero ranked memrefs supported\n");
+ LDBG() << "Zero-ranked memrefs not supported";
return failure();
}
@@ -2001,19 +2008,18 @@ static LogicalResult generateCopy(
std::optional<int64_t> numElements =
region.getConstantBoundingSizeAndShape(&fastBufferShape, &lbs);
if (!numElements) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
+ LDBG() << "Non-constant region size not supported";
return failure();
}
if (llvm::any_of(lbs, [](AffineMap lb) { return lb.getNumResults() > 1; })) {
// This can be supported in the future if needed.
- LLVM_DEBUG(llvm::dbgs()
- << "Max lower bound for memref region start not supported\n");
+ LDBG() << "Max lower bound for memref region start not supported";
return failure();
}
if (*numElements == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
+ LDBG() << "Nothing to copy";
return success();
}
@@ -2021,9 +2027,8 @@ static LogicalResult generateCopy(
for (unsigned i = 0; i < rank; ++i) {
region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
if (lbMaps[i].getNumResults() == 0 || ubMaps[i].getNumResults() == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Missing lower or upper bound for region along dimension: "
- << i << '\n');
+ LDBG() << "Missing lower or upper bound for region along dimension: "
+ << i;
return failure();
}
}
@@ -2122,7 +2127,7 @@ static LogicalResult generateCopy(
// TODO: use all stride levels once DmaStartOp is extended for
// multi-level strides.
if (dmaStrideInfos.size() > 1) {
- LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+ LDBG() << "Only up to one level of stride supported";
return failure();
}
@@ -2309,10 +2314,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
// surrounding the this block range.
unsigned copyDepth = getNestingDepth(&*begin);
- LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
- << "\n");
- LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
- LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
+ LDBG() << "Generating copies at depth " << copyDepth;
+ LDBG() << "from begin: "
+ << OpWithFlags(&*begin, OpPrintingFlags().skipRegions());
+ LDBG() << "to inclusive end: "
+ << OpWithFlags(&*std::prev(end), OpPrintingFlags().skipRegions());
// List of memory regions to copy for. We need a map vector to have a
// guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
@@ -2349,8 +2355,8 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
return;
if (!memref.getParentRegion()->isAncestor(block->getParent())) {
- LLVM_DEBUG(llvm::dbgs() << "memref definition is inside of the depth at "
- "which copy-in/copy-out would happen\n");
+ LDBG() << "memref definition is inside of the depth at "
+ << "which copy-in/copy-out would happen";
return;
}
@@ -2358,12 +2364,10 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
/*addMemRefDimBounds=*/false))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Error obtaining memory region: semi-affine maps?\n");
- LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
+ LDBG() << "Error obtaining memory region: semi-affine maps?";
+ LDBG() << "over-approximating to the entire memref";
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
- LLVM_DEBUG(
- opInst->emitError("non-constant memref sizes not yet supported"));
+ LDBG() << "non-constant memref sizes not yet supported";
error = true;
return;
}
@@ -2392,13 +2396,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
// Perform a union with the existing region.
if (failed(it->second->unionBoundingBox(*region))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Memory region bounding box failed; "
- "over-approximating to the entire memref\n");
+ LDBG() << "Memory region bounding box failed; "
+ << "over-approximating to the entire memref";
// If the union fails, we will overapproximate.
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
- LLVM_DEBUG(opInst->emitError(
- "non-constant memref sizes not yet supported"));
+ LDBG() << "non-constant memref sizes not yet supported";
error = true;
return true;
}
@@ -2428,8 +2430,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
});
if (error) {
- LLVM_DEBUG(begin->emitError(
- "copy generation failed for one or more memref's in this block\n"));
+ LDBG() << "copy generation failed for one or more memrefs in this block";
return failure();
}
@@ -2466,8 +2467,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
processRegions(writeRegions);
if (!ret) {
- LLVM_DEBUG(begin->emitError(
- "copy generation failed for one or more memref's in this block\n"));
+ LDBG() << "copy generation failed for one or more memrefs in this block";
return failure();
}
@@ -2608,7 +2608,7 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
/*boundFloorDivisor=*/nullptr,
/*ub=*/nullptr, &fullTileLbPos,
&fullTileUbPos)) {
- LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
+ LDBG() << "Can't get constant diff pair for a loop";
return nullptr;
}
@@ -2667,8 +2667,7 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
for (auto loop : inputNest) {
// TODO: straightforward to generalize to a non-unit stride.
if (loop.getStepAsInt() != 1) {
- LLVM_DEBUG(llvm::dbgs()
- << "[tile separation] non-unit stride not implemented\n");
+ LDBG() << "[tile separation] non-unit stride not implemented";
return failure();
}
SmallVector<Operation *, 1> loopOp{loop.getOperation()};
@@ -2682,8 +2681,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
/*boundFloorDivisor=*/nullptr,
/*ub=*/nullptr, &lbPos, &ubPos) ||
lbPos == ubPos) {
- LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
- "equalities not yet handled\n");
+ LDBG() << "[tile separation] Can't get constant diff / "
+ << "equalities not yet handled";
return failure();
}
@@ -2741,8 +2740,8 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
AffineIfOp ifOp = createSeparationCondition(inputNest, b);
if (!ifOp) {
fullTileLoops.front().erase();
- LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
- "separation condition\n");
+ LDBG() << "All tiles are full tiles, or failure creating "
+ << "separation condition";
return failure();
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c3..7d4d818 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 93682a9..4780dbb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -12,7 +12,7 @@ add_mlir_dialect_library(MLIRArithTransforms
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
DEPENDS
MLIRArithTransformsIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 1aa8064..35365f2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -158,13 +158,11 @@ protected:
PatternRewriter &rewriter) {
// Check iterator types for matrix multiplication.
SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
- if (!((itTypes.size() == 3 &&
- (itTypes[0] == vector::IteratorType::parallel &&
- itTypes[1] == vector::IteratorType::parallel &&
- itTypes[2] == vector::IteratorType::reduction)) ||
- (itTypes.size() == 2 &&
- (itTypes[0] == vector::IteratorType::parallel &&
- itTypes[1] == vector::IteratorType::reduction))))
+ if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction) &&
+ (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::reduction))
return rewriter.notifyMatchFailure(
op, "iterator types do not correspond to matrix multiplication");
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index 35b0bd1..6cb2a56 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -183,9 +183,9 @@ protected:
Value acc;
// Conventional names for matrix dimensions.
- int64_t M = 0;
- int64_t N = 0;
- int64_t K = 0;
+ int64_t m = 0;
+ int64_t n = 0;
+ int64_t k = 0;
// Create the matrix mulitply and accumulate operation according to
// `mmlaOp`.
@@ -286,41 +286,41 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Single-dimension vector type for the entire RHS tile.
- auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType,
+ auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType,
/*scalableDims=*/{true});
// Vector type having the same number of elements as a row in the
// accumulator/output tile and the same element type.
- auto accRowTy = VectorType::get(/*shape=*/N, resultEltType,
+ auto accRowTy = VectorType::get(/*shape=*/n, resultEltType,
/*scalableDims=*/{true});
// Vector type having twice the number of elements as a row in the
// accumulator/output tile the same element type.
- auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType,
+ auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType,
/*scalableDims=*/{true});
// Vector type having half the number of elements as a row in the
// accumulator/output tile and an integer element type with twice the bit
// width.
- auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+ auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(),
/*scalableDims=*/{true});
// Vector type having the same the number of elements as a row in the
// accumulator/output tile and an integer element type with twice the bit
// width.
- auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+ auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(),
/*scalableDims=*/{true});
Location loc = op.getLoc();
// Extract LHS sub-tiles with logical shape <2xK>.
SmallVector<Value> lhsTile;
- for (int64_t i = 0; i < M; i += 2) {
+ for (int64_t i = 0; i < m; i += 2) {
// Extract two consecutive rows of the LHS tile.
auto r0 =
vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
auto r1 =
vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
// Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
- SmallVector<int64_t> shuffleIdx(2 * K);
+ SmallVector<int64_t> shuffleIdx(2 * k);
std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
// Turn it into a scalable vector.
@@ -337,13 +337,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Extract the RHS sub-tiles with logical shape <Kx[2]>.
SmallVector<Value> rhsTile;
- for (int64_t j = 0; j < N; j += 2)
+ for (int64_t j = 0; j < n; j += 2)
rhsTile.push_back(vector::ScalableExtractOp::create(
- rewriter, loc, flatRhsType, rhs, j * K));
+ rewriter, loc, flatRhsType, rhs, j * k));
// Extract and pack the ACC sub-tiles.
SmallVector<Value> accTile;
- for (int64_t i = 0; i < M; i += 2) {
+ for (int64_t i = 0; i < m; i += 2) {
// Extract two consecutive rows of the accumulator tile.
auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
ArrayRef<int64_t>{i});
@@ -370,28 +370,28 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
}
// Extract ACC sub-tiles.
- for (int64_t j = 0; j < N; j += 2)
+ for (int64_t j = 0; j < n; j += 2)
accTile.push_back(vector::ScalableExtractOp::create(
rewriter, loc, flatAccType, accTileVec, j * 2));
}
// Emit sub-tile matrix multiplications.
SmallVector<Value> outTile;
- for (int64_t i = 0; i < M / 2; ++i)
- for (int64_t j = 0; j < N / 2; ++j) {
- Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
+ for (int64_t i = 0; i < m / 2; ++i)
+ for (int64_t j = 0; j < n / 2; ++j) {
+ Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i],
rhsTile[j]);
outTile.push_back(mmla);
}
// Unpack the OUT sub-tiles and insert into the result.
Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
- for (int64_t i = 0; i < M / 2; ++i) {
+ for (int64_t i = 0; i < m / 2; ++i) {
// Collect a number of sub-tiles in a row.
Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
- for (int64_t j = 0; j < N / 2; ++j)
+ for (int64_t j = 0; j < n / 2; ++j)
row = vector::ScalableInsertOp::create(
- rewriter, loc, outTile[i * N / 2 + j], row, j * 4);
+ rewriter, loc, outTile[i * n / 2 + j], row, j * 4);
// Unpack the row to obtain two rows of the output. If we have the out
// sub-tiles transposed we obtain two consecutive output rows by
@@ -432,9 +432,9 @@ public:
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- M = lhsType.getDimSize(0);
- N = rhsType.getDimSize(0);
- K = rhsType.getDimSize(1);
+ m = lhsType.getDimSize(0);
+ n = rhsType.getDimSize(0);
+ k = rhsType.getDimSize(1);
// Check the operands have the expected shape:
// * for LHS: fixed vector MxK
@@ -442,8 +442,8 @@ public:
// * K == 8
// * M and N even and at least 2
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
@@ -504,9 +504,9 @@ public:
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- M = lhsType.getDimSize(0);
- N = rhsType.getDimSize(0);
- K = rhsType.getDimSize(1);
+ m = lhsType.getDimSize(0);
+ n = rhsType.getDimSize(0);
+ k = rhsType.getDimSize(1);
// Check the operands have the expected shape:
// * for LHS: fixed vector MxK
@@ -514,8 +514,8 @@ public:
// * K == 4
// * M and N even and at least 2
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index ddc64ea..91e37dd 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -248,7 +248,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
Region *definingRegion = value.getParentRegion();
// Last users of the `value` inside all blocks where the value dies.
- llvm::SmallSet<Operation *, 4> lastUsers;
+ llvm::SmallPtrSet<Operation *, 4> lastUsers;
// Find blocks in the `definingRegion` that have users of the `value` (if
// there are multiple users in the block, which one will be selected is
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 7eb729f..56ff212 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
// which otherwise could prevent removal of unnecessary allocs.
Value canonicalSource = source;
while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
- canonicalSource.getDefiningOp()))
+ canonicalSource.getDefiningOp())) {
+ if (canonicalSource != iface.getViewDest()) {
+ break;
+ }
canonicalSource = iface.getViewSource();
+ }
std::optional<Operation *> maybeCloneDeallocOp =
memref::findDealloc(cloneOp.getOutput());
@@ -806,14 +810,12 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
if (!srcTensorType)
return failure();
auto currentOutputMemRefType =
- dyn_cast<MemRefType>(toBuffer.getResult().getType());
+ dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
if (!currentOutputMemRefType)
return failure();
- auto memrefType = MemRefType::get(srcTensorType.getShape(),
- srcTensorType.getElementType(),
- currentOutputMemRefType.getLayout(),
- currentOutputMemRefType.getMemorySpace());
+ auto memrefType = currentOutputMemRefType.cloneWith(
+ srcTensorType.getShape(), srcTensorType.getElementType());
Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
tensorCastOperand.getOperand(),
toBuffer.getReadOnly());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 8916526..a465c95 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -37,8 +37,12 @@ using namespace mlir::bufferization;
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
+ if (value != viewLikeOp.getViewDest()) {
+ break;
+ }
value = viewLikeOp.getViewSource();
+ }
return value;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 8f983ab..0b2e080 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
registerDependencies(viewInterface.getViewSource(),
- viewInterface->getResult(0));
+ viewInterface.getViewDest());
return WalkResult::advance();
}
@@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) {
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
+ if (value != viewLikeOp.getViewDest()) {
+ break;
+ }
value = viewLikeOp.getViewSource();
+ }
return value;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 91f6f25..68ef519 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -20,6 +20,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/DebugLog.h"
#include <optional>
namespace mlir {
@@ -328,20 +329,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
"blocks");
// Bufferize the op.
- LLVM_DEBUG(llvm::dbgs()
- << "//===-------------------------------------------===//\n"
- << "IR after bufferizing: " << nextOp->getName() << "\n");
+ LDBG(3) << "//===-------------------------------------------===//\n"
+ << "IR after bufferizing: " << nextOp->getName();
rewriter.setInsertionPoint(nextOp);
if (failed(
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
- LLVM_DEBUG(llvm::dbgs()
- << "failed to bufferize\n"
- << "//===-------------------------------------------===//\n");
+ LDBG(2) << "failed to bufferize\n"
+ << "//===-------------------------------------------===//";
return nextOp->emitError("failed to bufferize op");
}
- LLVM_DEBUG(llvm::dbgs()
- << *op
- << "\n//===-------------------------------------------===//\n");
+ LDBG(3) << *op << "\n//===-------------------------------------------===//";
}
// Return early if the top-level op is entirely gone.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index a8e8353..fb7f2bb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -56,6 +56,7 @@
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/DebugLog.h"
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
@@ -616,13 +617,10 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (getParallelRegion(def.getParentRegion(), options) !=
getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
options)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "\n- bufferizes out-of-place due to parallel region:\n");
- LLVM_DEBUG(llvm::dbgs()
- << " unConflictingWrite = operand "
- << uConflictingWrite->getOperandNumber() << " of "
- << *uConflictingWrite->getOwner() << "\n");
+ LDBG() << "\n- bufferizes out-of-place due to parallel region:\n"
+ << " uConflictingWrite = operand "
+ << uConflictingWrite->getOperandNumber() << " of "
+ << *uConflictingWrite->getOwner();
return true;
}
}
@@ -631,9 +629,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
- LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
- LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
- << " of " << *readingOp << "\n");
+ LDBG() << "\n- check conflict:\n"
+ << " uRead = operand " << uRead->getOperandNumber() << " of "
+ << *readingOp;
// Find the definition of uRead by following the SSA use-def chain.
// E.g.:
@@ -648,23 +646,22 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read value has no definitions\n");
+ LDBG() << " no conflict: read value has no definitions";
continue;
}
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
for (OpOperand *uConflictingWrite : usesWrite) {
- LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
- << uConflictingWrite->getOperandNumber() << " of "
- << *uConflictingWrite->getOwner() << "\n");
+ LDBG() << " uConflictingWrite = operand "
+ << uConflictingWrite->getOperandNumber() << " of "
+ << *uConflictingWrite->getOwner();
// Check if op dominance can be used to rule out read-after-write
// conflicts.
bool useDominance =
canUseOpDominance(uRead, uConflictingWrite, definitions, state);
- LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
+ LDBG() << "\n- useDominance = " << useDominance;
// Throughout this loop, check for multiple requirements that have to be
// met for uConflictingWrite to be an actual conflict.
@@ -680,8 +677,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// inside a loop), there may be no meaningful `happensBefore`
// relationship.
if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read happens before write\n");
+ LDBG() << " no conflict: read happens before write";
continue;
}
@@ -693,8 +689,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// Note: If the op is executed multiple times (e.g., because it is
// inside a loop), it may be conflicting with itself.
if (uConflictingWrite == uRead) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read and write are same use\n");
+ LDBG() << " no conflict: read and write are same use";
continue;
}
@@ -705,8 +700,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// multiple times.
if (state.insideMutuallyExclusiveRegions(readingOp,
conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
- "mutually exclusive regions\n");
+ LDBG() << " no conflict: read and write are in "
+ "mutually exclusive regions";
continue;
}
@@ -721,9 +716,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
state, uRead, uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
state, uConflictingWrite, uRead->get())) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: op bufferizes to element-wise access\n");
+ LDBG() << " no conflict: op bufferizes to element-wise access";
continue;
}
}
@@ -733,15 +726,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// No conflict if the operands are non-conflicting subsets.
if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
+ LDBG() << " no conflict: non-conflicting subsets";
continue;
}
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: op interace of reading op says 'no'\n");
+ LDBG() << " no conflict: op interface of reading op says 'no'";
continue;
}
}
@@ -751,9 +743,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
options.dynCastBufferizableOp(conflictingWritingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
state)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: op interace of writing op says 'no'\n");
+ LDBG() << " no conflict: op interface of writing op says 'no'";
continue;
}
}
@@ -761,29 +751,26 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// Check all possible definitions.
for (Value definition : definitions) {
- LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
+ LDBG() << " * definition = " << definition;
// No conflict if the conflicting write happens before the definition.
if (Operation *defOp = definition.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
// conflictingWritingOp happens before defOp. No conflict.
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: write happens before definition\n");
+ LDBG() << " no conflict: write happens before definition";
continue;
}
// No conflict if conflictingWritingOp is contained in defOp.
if (defOp->isProperAncestor(conflictingWritingOp)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: write is contained in definition\n");
+ LDBG() << " no conflict: write is contained in definition";
continue;
}
} else {
auto bbArg = cast<BlockArgument>(definition);
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
- "and write happens outside of block\n");
+ LDBG() << " no conflict: definition is bbArg "
+ "and write happens outside of block";
// conflictingWritingOp happens outside of the block. No
// conflict.
continue;
@@ -795,8 +782,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
if (aliases.getNumAliases() == 1 &&
aliases.getAliases()[0].value == definition) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: definition and write are same\n");
+ LDBG() << " no conflict: definition and write are same";
continue;
}
@@ -804,7 +790,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (options.printConflicts)
annotateConflict(uRead, uConflictingWrite, definition);
- LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
+ LDBG() << " => RaW CONFLICT FOUND";
return true;
}
}
@@ -958,7 +944,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
for (AliasingValue alias : state.getAliasingValues(operand))
state.applyOnAliases(alias.value, checkReadOnly);
if (foundReadOnly) {
- LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
+ LDBG() << "=> NOT WRITABLE";
return true;
}
@@ -987,10 +973,9 @@ void OneShotAnalysisState::resetCache() {
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
const DominanceInfo &domInfo) {
- LLVM_DEBUG(
- llvm::dbgs() << "//===-------------------------------------------===//\n"
- << "Analyzing operand #" << operand.getOperandNumber()
- << " of " << *operand.getOwner() << "\n");
+ LDBG() << "//===-------------------------------------------===//\n"
+ << "Analyzing operand #" << operand.getOperandNumber() << " of "
+ << *operand.getOwner();
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, state) ||
@@ -1001,8 +986,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
else
state.bufferizeInPlace(operand);
- LLVM_DEBUG(llvm::dbgs()
- << "//===-------------------------------------------===//\n");
+ LDBG() << "//===-------------------------------------------===//";
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 725fa24..b593cca 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -51,14 +51,8 @@ static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
/// Return "true" if the given op is guaranteed to have neither "Allocate" nor
/// "Free" side effects.
static bool hasNeitherAllocateNorFreeSideEffect(Operation *op) {
- if (isa<MemoryEffectOpInterface>(op))
- return !hasEffect<MemoryEffects::Allocate>(op) &&
- !hasEffect<MemoryEffects::Free>(op);
- // If the op does not implement the MemoryEffectOpInterface but has has
- // recursive memory effects, then this op in isolation (without its body) does
- // not have any side effects. All the ops inside the regions of this op will
- // be processed separately.
- return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+ return !mightHaveEffect<MemoryEffects::Allocate>(op) &&
+ !mightHaveEffect<MemoryEffects::Free>(op);
}
/// Return "true" if the given op has buffer semantics. I.e., it has buffer
@@ -517,9 +511,7 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
// MemoryEffectOpInterface. They usually do not have side effects apart
// from the callee, which will be analyzed separately. (This is similar to
// "recursive memory effects".)
- if (!isa<MemoryEffectOpInterface>(op) &&
- !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() &&
- !isa<CallOpInterface>(op))
+ if (hasUnknownEffects(op) && !isa<CallOpInterface>(op))
return op->emitError(
"ops with unknown memory side effects are not supported");
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 053ee95..0acb4b1 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(Transform)
add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
+add_subdirectory(WasmSSA)
add_subdirectory(X86Vector)
add_subdirectory(XeGPU)
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 37b4cfc..47740d3 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms
BufferizableOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
LINK_LIBS PUBLIC
MLIRBufferizationDialect
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index 34f2dd5..3f6dd29 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -24,7 +24,7 @@ LogicalResult mlir::impl::verifyHasDefaultDLTIDataLayoutTrait(Operation *op) {
}
DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
- return op->getAttrOfType<DataLayoutSpecAttr>(
+ return op->getAttrOfType<DataLayoutSpecInterface>(
DLTIDialect::kDataLayoutAttrName);
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6a3154..00ce3b5 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -114,11 +114,8 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
- case 16: {
- if (llvm::isa<Float16Type, BFloat16Type>(type))
- return true;
- return false;
- }
+ case 16:
+ return llvm::isa<Float16Type, BFloat16Type>(type);
case 32:
case 64:
return true;
@@ -134,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
type);
}
+bool mlir::emitc::isFundamentalType(Type type) {
+ return llvm::isa<IndexType>(type) || isPointerWideType(type) ||
+ isSupportedIntegerType(type) || isSupportedFloatType(type) ||
+ isa<emitc::PointerType>(type);
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
@@ -378,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
// ExpressionOp
//===----------------------------------------------------------------------===//
+ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ if (parser.parseOperandList(operands))
+ return parser.emitError(parser.getCurrentLocation()) << "expected operands";
+ if (succeeded(parser.parseOptionalKeyword("noinline")))
+ result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name),
+ parser.getBuilder().getUnitAttr());
+ Type type;
+ if (parser.parseColonType(type))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected function type");
+ auto fnType = llvm::dyn_cast<FunctionType>(type);
+ if (!fnType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected function type");
+ if (parser.resolveOperands(operands, fnType.getInputs(),
+ parser.getCurrentLocation(), result.operands))
+ return failure();
+ if (fnType.getNumResults() != 1)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected single return type");
+ result.addTypes(fnType.getResults());
+ Region *body = result.addRegion();
+ SmallVector<OpAsmParser::Argument> argsInfo;
+ for (auto [unresolvedOperand, operandType] :
+ llvm::zip(operands, fnType.getInputs())) {
+ OpAsmParser::Argument argInfo;
+ argInfo.ssaName = unresolvedOperand;
+ argInfo.type = operandType;
+ argsInfo.push_back(argInfo);
+ }
+ if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+ return failure();
+ return success();
+}
+
+void emitc::ExpressionOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p.printOperands(getDefs());
+ p << " : ";
+ p.printFunctionalType(getOperation());
+ p.shadowRegionArgs(getRegion(), getDefs());
+ p << ' ';
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
Operation *ExpressionOp::getRootOp() {
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
Value yieldedValue = yieldOp.getResult();
@@ -1398,6 +1447,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
+
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
TypeAttr type,
Attribute initialValue) {
@@ -1455,6 +1505,15 @@ LogicalResult FieldOp::verify() {
//===----------------------------------------------------------------------===//
// GetFieldOp
//===----------------------------------------------------------------------===//
+
+LogicalResult GetFieldOp::verify() {
+ auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
+ if (!parentClassOp.getOperation())
+ return emitOpError(" must be nested within an emitc.class operation");
+
+ return success();
+}
+
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
FieldOp fieldOp =
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 3f0690c..f8469b8 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -9,7 +9,9 @@
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace emitc {
@@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
Location loc = op->getLoc();
builder.setInsertionPointAfter(op);
- auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType);
+ auto expressionOp =
+ emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
// Replace all op's uses with the new expression's result.
result.replaceAllUsesWith(expressionOp.getResult());
- // Create an op to yield op's value.
- Region &region = expressionOp.getRegion();
- Block &block = region.emplaceBlock();
+ Block &block = expressionOp.createBody();
+ IRMapping mapper;
+ for (auto [operand, arg] :
+ llvm::zip(expressionOp.getOperands(), block.getArguments()))
+ mapper.map(operand, arg);
builder.setInsertionPointToEnd(&block);
- auto yieldOp = emitc::YieldOp::create(builder, loc, result);
- // Move op into the new expression.
- op->moveBefore(yieldOp);
+ Operation *rootOp = builder.clone(*op, mapper);
+ op->erase();
+ // Create an op to yield op's value.
+ emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
return expressionOp;
}
@@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
using OpRewritePattern<ExpressionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpressionOp expressionOp,
PatternRewriter &rewriter) const override {
- bool anythingFolded = false;
- for (Operation &op : llvm::make_early_inc_range(
- expressionOp.getBody()->without_terminator())) {
- // Don't fold expressions whose result value has its address taken.
- auto applyOp = dyn_cast<emitc::ApplyOp>(op);
- if (applyOp && applyOp.getApplicableOperator() == "&")
- continue;
-
- for (Value operand : op.getOperands()) {
- auto usedExpression = operand.getDefiningOp<ExpressionOp>();
- if (!usedExpression)
- continue;
-
- // Don't fold expressions with multiple users: assume any
- // re-materialization was done separately.
- if (!usedExpression.getResult().hasOneUse())
- continue;
-
- // Don't fold expressions with side effects.
- if (usedExpression.hasSideEffects())
- continue;
-
- // Fold the used expression into this expression by cloning all
- // instructions in the used expression just before the operation using
- // its value.
- rewriter.setInsertionPoint(&op);
- IRMapping mapper;
- for (Operation &opToClone :
- usedExpression.getBody()->without_terminator()) {
- Operation *clone = rewriter.clone(opToClone, mapper);
- mapper.map(&opToClone, clone);
- }
-
- Operation *expressionRoot = usedExpression.getRootOp();
- Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
- assert(clonedExpressionRootOp &&
- "Expected cloned expression root to be in mapper");
- assert(clonedExpressionRootOp->getNumResults() == 1 &&
- "Expected cloned root to have a single result");
-
- rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
- anythingFolded = true;
- }
+ Block *expressionBody = expressionOp.getBody();
+ ExpressionOp usedExpression;
+ SetVector<Value> foldedOperands;
+
+ auto takesItsOperandsAddress = [](Operation *user) {
+ auto applyOp = dyn_cast<emitc::ApplyOp>(user);
+ return applyOp && applyOp.getApplicableOperator() == "&";
+ };
+
+ // Select, as the expression to fold, the first operand expression that
+ // - doesn't have its result value's address taken,
+ // - has a single user: assume any re-materialization was done separately,
+ // - has no side effects,
+ // and save all other operands to be used later as operands in the folded
+ // expression.
+ for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
+ expressionBody->getArguments())) {
+ ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
+ if (usedExpression || !operandExpression ||
+ llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
+ !operandExpression.getResult().hasOneUse() ||
+ operandExpression.hasSideEffects())
+ foldedOperands.insert(operand);
+ else
+ usedExpression = operandExpression;
}
- return anythingFolded ? success() : failure();
+
+ // If no operand expression was selected, bail out.
+ if (!usedExpression)
+ return failure();
+
+ // Collect additional operands from the folded expression.
+ for (Value operand : usedExpression.getOperands())
+ foldedOperands.insert(operand);
+
+ // Create a new expression to hold the folding result.
+ rewriter.setInsertionPointAfter(expressionOp);
+ auto foldedExpression = emitc::ExpressionOp::create(
+ rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+ foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
+ Block &foldedExpressionBody = foldedExpression.createBody();
+
+ // Map each operand of the new expression to its matching block argument.
+ IRMapping mapper;
+ for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
+ foldedExpressionBody.getArguments()))
+ mapper.map(operand, arg);
+
+ // Prepare to fold the used expression and the matched expression into the
+ // newly created folded expression.
+ auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
+ bool withTerminator) {
+ Block *expressionToFoldBody = expressionToFold.getBody();
+ for (auto [operand, arg] :
+ llvm::zip(expressionToFold.getOperands(),
+ expressionToFoldBody->getArguments())) {
+ mapper.map(arg, mapper.lookup(operand));
+ }
+
+ for (Operation &opToClone : expressionToFoldBody->without_terminator())
+ rewriter.clone(opToClone, mapper);
+
+ if (withTerminator)
+ rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
+ };
+ rewriter.setInsertionPointToStart(&foldedExpressionBody);
+
+ // First, fold the used expression into the new expression and map its
+ // result to the clone of its root operation within the new expression.
+ foldExpression(usedExpression, /*withTerminator=*/false);
+ Operation *expressionRoot = usedExpression.getRootOp();
+ Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
+ assert(clonedExpressionRootOp &&
+ "Expected cloned expression root to be in mapper");
+ assert(clonedExpressionRootOp->getNumResults() == 1 &&
+ "Expected cloned root to have a single result");
+ mapper.map(usedExpression.getResult(),
+ clonedExpressionRootOp->getResults()[0]);
+
+ // Now fold the matched expression into the new expression.
+ foldExpression(expressionOp, /*withTerminator=*/true);
+
+ // Complete the rewrite.
+ rewriter.replaceOp(expressionOp, foldedExpression);
+ rewriter.eraseOp(usedExpression);
+
+ return success();
}
};
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index c55e26e..06d7e07 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -64,8 +64,8 @@ public:
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
- FieldOp fieldop = rewriter.create<emitc::FieldOp>(
- funcOp->getLoc(), fieldName, typeAttr, nullptr);
+ FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+ fieldName, typeAttr, nullptr);
if (argAttrs && idx < argAttrs->size()) {
fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5a72ef1..b87b4f4 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -756,7 +756,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
Type asyncTokenType, ValueRange asyncDependencies,
TypeRange workgroupAttributions,
TypeRange privateAttributions, Value clusterSizeX,
- Value clusterSizeY, Value clusterSizeZ) {
+ Value clusterSizeY, Value clusterSizeZ,
+ FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
OpBuilder::InsertionGuard g(builder);
// Add a WorkGroup attribution attribute. This attribute is required to
@@ -781,6 +782,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
if (dynamicSharedMemorySize)
result.addOperands(dynamicSharedMemorySize);
+ // Add optional module and function attributes.
+ if (module)
+ result.addAttribute(getModuleAttrName(result.name), module);
+ if (function)
+ result.addAttribute(getFunctionAttrName(result.name), function);
+
// Create a kernel body region with kNumConfigRegionAttributes + N memory
// attributions, where the first kNumConfigRegionAttributes arguments have
// `index` type and the rest have the same types as the data operands.
@@ -944,6 +951,21 @@ void LaunchOp::print(OpAsmPrinter &p) {
p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
<< getDynamicSharedMemorySize();
+ // Print optional module attribute.
+ StringRef moduleAttrName = getModuleAttrName();
+ if (auto module = getModule()) {
+ p << ' ' << moduleAttrName << '(';
+ p.printSymbolName(*module);
+ p << ')';
+ }
+ // Print optional function attribute.
+ StringRef functionAttrName = getFunctionAttrName();
+ if (auto function = getFunction()) {
+ p << ' ' << functionAttrName << '(';
+ p.printSymbolName(*function);
+ p << ')';
+ }
+
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
@@ -952,7 +974,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
LaunchOp::getOperandSegmentSizeAttr(),
- getNumWorkgroupAttributionsAttrName()});
+ getNumWorkgroupAttributionsAttrName(),
+ moduleAttrName, functionAttrName});
}
// Parse the size assignment blocks for blocks and threads. These have the form
@@ -990,6 +1013,9 @@ parseSizeAssignment(OpAsmParser &parser,
/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
+/// (`dynamic_shared_memory_size` ssa-use)?
+/// (`module(` symbol-ref-id `)`)?
+/// (`function(` symbol-ref-id `)`)?
/// memory-attribution
/// region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -1060,6 +1086,27 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
+ // Parse optional module attribute.
+ StringRef moduleAttrName = getModuleAttrName(result.name);
+ if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
+ FlatSymbolRefAttr moduleSymbol;
+ if (parser.parseLParen() ||
+ parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
+ result.attributes) ||
+ parser.parseRParen())
+ return failure();
+ }
+ // Parse optional function attribute.
+ StringRef functionAttrName = getFunctionAttrName(result.name);
+ if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
+ FlatSymbolRefAttr funcSymbol;
+ if (parser.parseLParen() ||
+ parser.parseAttribute(funcSymbol, Type(), functionAttrName,
+ result.attributes) ||
+ parser.parseRParen())
+ return failure();
+ }
+
// Create the region arguments, it has kNumConfigRegionAttributes arguments
// that correspond to block/thread identifiers and grid/block sizes, all
// having `index` type, a variadic number of WorkGroup Attributions and
@@ -2439,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
if (getArgs().size() != getWarpRegion().getNumArguments())
return emitOpError(
"expected same number op arguments and block arguments.");
- auto yield =
- cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = getTerminator();
if (yield.getNumOperands() != getNumResults())
return emitOpError(
"expected same number of yield operands and return values.");
@@ -2464,6 +2510,50 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
+gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
+ return cast<gpu::YieldOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupBroadcastOp
+//===----------------------------------------------------------------------===//
+
+void gpu::SubgroupBroadcastOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), argRanges.front());
+}
+
+Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
+ switch (getBroadcastType()) {
+ case BroadcastType::first_active_lane:
+ // Cannot speculate first_active_lane broadcasts: speculating them across
+ // control flow can change the active lanes.
+ return Speculation::NotSpeculatable;
+ case BroadcastType::any_lane:
+ LLVM_FALLTHROUGH;
+ case BroadcastType::specific_lane:
+ // Speculation should be safe as long as we are inside structured control flow.
+ return Speculation::Speculatable;
+ }
+}
+
+LogicalResult gpu::SubgroupBroadcastOp::verify() {
+ switch (getBroadcastType()) {
+ case BroadcastType::first_active_lane:
+ LLVM_FALLTHROUGH;
+ case BroadcastType::any_lane:
+ if (getLane())
+ return emitOpError()
+ << "lane can only be specified for `specific_lane` broadcast";
+ return success();
+ case BroadcastType::specific_lane:
+ if (!getLane())
+ return emitOpError()
+ << "lane must be specified for `specific_lane` broadcast";
+ return success();
+ }
+}
+
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 21cb2f6..c766539 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
@@ -43,6 +44,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
+#include <optional>
#include <type_traits>
using namespace mlir;
@@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ std::optional<StringRef> chipsetName = getChipset();
+ std::optional<amdgpu::Chipset> maybeChipset;
+ if (chipsetName) {
+ FailureOr<amdgpu::Chipset> parsedChipset =
+ amdgpu::Chipset::parse(*chipsetName);
+ assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
+ maybeChipset = parsedChipset;
+ }
+
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 9bf11c7..d2c2138 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -25,6 +25,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_GPUELIMINATEBARRIERS
@@ -37,9 +38,6 @@ using namespace mlir::gpu;
#define DEBUG_TYPE "gpu-erase-barriers"
#define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
-
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.
@@ -424,27 +422,18 @@ static bool maybeCaptured(Value v) {
/// everything. This seems sufficient to achieve barrier removal in structured
/// control flow, more complex cases would require a proper dataflow analysis.
static bool mayAlias(Value first, Value second) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
- DBGS_ALIAS() << "checking aliasing between ";
- DBGS_ALIAS() << first << "\n";
- DBGS_ALIAS() << " and ";
- DBGS_ALIAS() << second << "\n";
- });
+ LDBG(DEBUG_TYPE_ALIAS, 1)
+ << "checking aliasing between " << first << " and " << second;
first = getBase(first);
second = getBase(second);
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
- DBGS_ALIAS() << "base ";
- DBGS_ALIAS() << first << "\n";
- DBGS_ALIAS() << " and ";
- DBGS_ALIAS() << second << "\n";
- });
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "base " << first << " and " << second;
// Values derived from the same base memref do alias (unless we do a more
// advanced analysis to prove non-overlapping accesses).
if (first == second) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n");
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "-> do alias!";
return true;
}
@@ -493,7 +482,7 @@ static bool mayAlias(Value first, Value second) {
return false;
// Otherwise, conservatively assume aliasing.
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n");
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "-> may alias!";
return true;
}
@@ -567,20 +556,16 @@ haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
continue;
// Other kinds of effects create a conflict, e.g. read-after-write.
- LLVM_DEBUG(
- DBGS() << "found a conflict between (before): " << before.getValue()
- << " read:" << isa<MemoryEffects::Read>(before.getEffect())
- << " write:" << isa<MemoryEffects::Write>(before.getEffect())
- << " alloc:"
- << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:"
- << isa<MemoryEffects::Free>(before.getEffect()) << "\n");
- LLVM_DEBUG(
- DBGS() << "and (after): " << after.getValue()
- << " read:" << isa<MemoryEffects::Read>(after.getEffect())
- << " write:" << isa<MemoryEffects::Write>(after.getEffect())
- << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
- << " free:" << isa<MemoryEffects::Free>(after.getEffect())
- << "\n");
+ LDBG() << "found a conflict between (before): " << before.getValue()
+ << " read:" << isa<MemoryEffects::Read>(before.getEffect())
+ << " write:" << isa<MemoryEffects::Write>(before.getEffect())
+ << " alloc:" << isa<MemoryEffects::Allocate>(before.getEffect())
+ << " free:" << isa<MemoryEffects::Free>(before.getEffect());
+ LDBG() << "and (after): " << after.getValue()
+ << " read:" << isa<MemoryEffects::Read>(after.getEffect())
+ << " write:" << isa<MemoryEffects::Write>(after.getEffect())
+ << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
+ << " free:" << isa<MemoryEffects::Free>(after.getEffect());
return true;
}
}
@@ -595,8 +580,8 @@ public:
LogicalResult matchAndRewrite(BarrierOp barrier,
PatternRewriter &rewriter) const override {
- LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "
- << barrier.getLoc() << "\n");
+ LDBG() << "checking the necessity of: " << barrier << " "
+ << barrier.getLoc();
SmallVector<MemoryEffects::EffectInstance> beforeEffects;
getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true);
@@ -605,14 +590,12 @@ public:
getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true);
if (!haveConflictingEffects(beforeEffects, afterEffects)) {
- LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "
- << barrier << "\n");
+ LDBG() << "the surrounding barriers are sufficient, removing " << barrier;
rewriter.eraseOp(barrier);
return success();
}
- LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "
- << barrier.getLoc() << "\n");
+ LDBG() << "barrier is necessary: " << barrier << " " << barrier.getLoc();
return failure();
}
};
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 99f5c5b..97adad6 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -356,8 +356,8 @@ public:
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
SetVector<Value> operands;
std::string kernelFnName;
- if (op.getKernelFunc()) {
- kernelFnName = op.getKernelFunc()->getRootReference().str();
+ if (op.getFunction()) {
+ kernelFnName = op.getFunction()->str();
} else {
kernelFnName =
Twine(op->getParentOfType<SymbolOpInterface>().getName(),
@@ -403,9 +403,8 @@ private:
OpBuilder builder(context);
std::string kernelModuleName;
gpu::GPUModuleOp kernelModule;
- if (gpuLaunchOp.getKernelModule()) {
- kernelModuleName =
- gpuLaunchOp.getKernelModule()->getRootReference().str();
+ if (gpuLaunchOp.getModule()) {
+ kernelModuleName = gpuLaunchOp.getModule()->str();
kernelModule =
parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName);
} else {
@@ -432,8 +431,7 @@ private:
if (std::optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName =
- cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
+ StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 18c69f5..67cef8a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -11,16 +11,21 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
+#include <optional>
using namespace mlir;
namespace {
+
+constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
+
/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
/// and offset must be a constant integer in the range [0, 31].
struct PromoteShuffleToSwizzlePattern
@@ -56,9 +61,48 @@ struct PromoteShuffleToSwizzlePattern
return success();
}
};
+
+/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToPermlanePattern
+ : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getMode() != gpu::ShuffleMode::XOR)
+ return rewriter.notifyMatchFailure(op,
+ "only xor shuffle mode is supported");
+
+ if (!isConstantIntValue(op.getWidth(), 64))
+ return rewriter.notifyMatchFailure(op,
+ "only 64 width shuffle is supported");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ if (offsetValue != 16 && offsetValue != 32)
+      return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
+
+ Location loc = op.getLoc();
+ Value res = amdgpu::PermlaneSwapOp::create(
+ rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
+ Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
+ rewriter.replaceOp(op, {res, valid});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
- RewritePatternSet &patterns) {
- patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
+ RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
+ patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
+ /*benefit*/ 1);
+ if (maybeChipset && *maybeChipset >= kGfx950)
+ patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
+ /*benefit*/ 2);
}
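A usage sketch for the updated entry point (illustrative only; the wrapper
function and the chipset string are hypothetical):

// Illustrative only: how a caller might gate the permlane pattern on gfx950+.
static void addAMDGPUShufflePatterns(RewritePatternSet &patterns,
                                     StringRef chipsetName /*e.g. "gfx950"*/) {
  std::optional<amdgpu::Chipset> chipset;
  if (FailureOr<amdgpu::Chipset> parsed = amdgpu::Chipset::parse(chipsetName);
      succeeded(parsed))
    chipset = *parsed;
  // Without a (new enough) chipset only the swizzle pattern is registered;
  // on gfx950+ the higher-benefit permlane_swap pattern is added as well.
  populateGpuPromoteShuffleToAMDGPUPatterns(patterns, chipset);
}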
diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
index e9cf493..6da76e9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "llvm/Support/Regex.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 384d1a0..88f531f 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
#include <numeric>
@@ -55,28 +56,30 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
SmallVector<size_t> &indices) const {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
+ gpu::YieldOp yield = warpOp.getTerminator();
+ SmallVector<Value> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ llvm::SmallDenseMap<Value, unsigned> indexLookup;
+ // Record the value -> first index mapping for faster lookup.
+ for (auto [i, v] : llvm::enumerate(yieldValues)) {
+ if (!indexLookup.count(v))
+ indexLookup[v] = i;
+ }
+
for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(value)) {
+ // If the value already exists in the yield, don't create a new output.
+ if (indexLookup.count(value)) {
+ indices.push_back(indexLookup[value]);
+ } else {
+ // If the value is new, add it to the yield and to the types.
+ yieldValues.push_back(value);
types.push_back(type);
indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == value) {
- indices.push_back(idx);
- break;
- }
- }
}
}
- yieldValues.insert_range(newYieldedValues);
+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
+ rewriter, warpOp, yieldValues, types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
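A worked example of the de-duplication above (illustrative only; the SSA
values are hypothetical):

// Existing yield = (%a, %b); newYieldedValues = {%b, %c}.
//   indexLookup: %a -> 0, %b -> 1
//   %b is already yielded -> indices += {1}, yield unchanged
//   %c is new             -> yield becomes (%a, %b, %c), indices += {2}
// The rebuilt warp op thus gains exactly one extra result, and callers locate
// the distributed values through `indices`.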
@@ -85,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
OpOperand *WarpDistributionPattern::getWarpResult(
WarpExecuteOnLane0Op warpOp,
llvm::function_ref<bool(Operation *)> fn) const {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
for (OpOperand &yieldOperand : yield->getOpOperands()) {
Value yieldValues = yieldOperand.get();
Operation *definedOp = yieldValues.getDefiningOp();
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ff55f17..ec581ac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
+ MLIRPtrMemorySpaceInterfaces
MLIRSideEffectInterfaces
MLIRSupport
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 894de44..7220e10 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,20 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
@@ -28,50 +38,122 @@ using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
-static char getRegisterType(Type type) {
- if (type.isInteger(1))
- return 'b';
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == kSharedMemorySpace) {
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+ MLIRContext *ctx = type.getContext();
+ auto i16 = IntegerType::get(ctx, 16);
+ auto i32 = IntegerType::get(ctx, 32);
+ auto f32 = Float32Type::get(ctx);
+
+ auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
+ if (type.isInteger(1))
+ return 'b';
+ if (type.isInteger(16))
+ return 'h';
+ if (type.isInteger(32))
return 'r';
+ if (type.isInteger(64))
+ return 'l';
+ if (type.isF32())
+ return 'f';
+ if (type.isF64())
+ return 'd';
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
+      // The shared address space is addressed with 32-bit pointers.
+ if (ptr.getAddressSpace() == kSharedMemorySpace) {
+ return 'r';
+ }
+ return 'l';
}
- return 'l';
+    // Register types for structs are not supported.
+    mlir::emitError(
+        loc, "The register type could not be deduced from MLIR type. The ")
+        << type
+        << " is not supported. Supported types are: "
+           "i1, i16, i32, i64, f32, f64, and pointers.\n"
+           "Please use llvm.bitcast if you have a different type.\n"
+           "See the constraints here: "
+           "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+           "index.html#constraints";
+ return failure();
+ };
+
+ // Packed registers
+ if (auto v = dyn_cast<VectorType>(type)) {
+ assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
+
+ int64_t lanes = v.getNumElements();
+ Type elem = v.getElementType();
+
+ // Case 1. Single vector
+ if (lanes <= 1)
+ return getRegisterTypeForScalar(elem);
+
+ // Case 2. Packed registers
+ Type widened = elem;
+ switch (lanes) {
+
+ case 2:
+ if (elem.isF16() || elem.isBF16()) // vector<2xf16>
+ widened = f32;
+ else if (elem.isFloat(8)) // vector<2xf8>
+ widened = i16;
+ break;
+ case 4:
+      if (elem.isInteger(8)) // vector<4xi8>
+        widened = i32;
+      else if (elem.isFloat(8)) // vector<4xf8>
+        widened = f32;
+      else if (elem.isFloat(4)) // vector<4xf4>
+        widened = i16;
+      break;
+      // Other packings are not supported.
+ default:
+ break;
+ }
+ return getRegisterTypeForScalar(widened);
}
- // register type for struct is not supported.
- llvm_unreachable("The register type could not deduced from MLIR type");
- return '?';
+
+ return getRegisterTypeForScalar(type);
}
-static char getRegisterType(Value v) {
+static FailureOr<char> getRegisterType(Value v, Location loc) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
- return getRegisterType(v.getType());
+ return getRegisterType(v.getType(), loc);
}
-void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
- LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+/// Extract every element of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+ Location loc, Value structVal) {
+ auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
+ assert(structTy && "expected LLVM struct");
+
+ SmallVector<Value> elems;
+ for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
+ elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i));
+
+ return elems;
+}
+
+LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+  LDBG() << v << "\t Modifier : " << itype;
+ registerModifiers.push_back(itype);
+
+ Location loc = interfaceOp->getLoc();
auto getModifier = [&]() -> const char * {
- if (itype == PTXRegisterMod::ReadWrite) {
- assert(false && "Read-Write modifier is not supported. Try setting the "
- "same value as Write and Read separately.");
- return "+";
- }
- if (itype == PTXRegisterMod::Write) {
+ switch (itype) {
+ case PTXRegisterMod::Read:
+ return "";
+ case PTXRegisterMod::Write:
return "=";
+ case PTXRegisterMod::ReadWrite:
+ // "Read-Write modifier is not actually supported
+ // Interface will change it to "=" later and add integer mapping
+ return "+";
}
- return "";
+ llvm_unreachable("Unknown PTX register modifier");
};
+
auto addValue = [&](Value v) {
if (itype == PTXRegisterMod::Read) {
ptxOperands.push_back(v);
@@ -90,35 +172,273 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
- Value extractValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), v, idx);
+ Value extractValue =
+ LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
- ss << getModifier() << getRegisterType(t) << ",";
+ FailureOr<char> regType = getRegisterType(t, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc,
+ "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
}
- return;
+ return success();
}
// Handle Scalars
addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
+ FailureOr<char> regType = getRegisterType(v, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
+ return success();
+}
+
+/// Check if the operation needs to pack and unpack results.
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+ if (needsManualRegisterMapping)
+ return false;
+ const unsigned writeOnlyVals = interfaceOp->getNumResults();
+ const unsigned readWriteVals =
+ llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+ return (writeOnlyVals + readWriteVals) > 1;
+}
+
+/// Pack the result types of the interface operation.
+/// If the operation has multiple results, it packs them into a struct
+/// type. Otherwise, it returns the original result types.
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+ SmallVectorImpl<Value> &ptxOperands) {
+ MLIRContext *ctx = interfaceOp->getContext();
+ TypeRange resultRange = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ // Single value path:
+ if (interfaceOp->getResults().size() == 1)
+ return SmallVector<Type>{resultRange.front()};
+
+ // No declared results: if there is an RW, forward its type.
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ return SmallVector<Type>{v.getType()};
+ }
+
+ SmallVector<Type> packed;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ packed.push_back(v.getType());
+ for (Type t : resultRange)
+ packed.push_back(t);
+
+ if (packed.empty())
+ return {};
+
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+ return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+/// - Turn every "+X" into "=X"
+/// - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+ SmallVector<llvm::StringRef> toks;
+ SmallVector<std::string> out;
+ SmallVector<unsigned> plusIdx;
+
+ csv.split(toks, ',');
+ out.reserve(toks.size() + 8);
+
+ for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+ StringRef t = toks[i].trim();
+ if (t.consume_front("+")) {
+ plusIdx.push_back(i);
+ out.push_back(("=" + t).str());
+ } else {
+ out.push_back(t.str());
+ }
+ }
+
+ // Append indices of original "+X" tokens.
+ for (unsigned idx : plusIdx)
+ out.push_back(std::to_string(idx));
+
+ // Join back to CSV.
+ std::string result;
+ result.reserve(csv.size() + plusIdx.size() * 2);
+ llvm::raw_string_ostream os(result);
+ for (size_t i = 0; i < out.size(); ++i) {
+ if (i)
+ os << ',';
+ os << out[i];
+ }
+ return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
+constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
+constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
+
+/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
+static llvm::Regex getPredicateMappingRegex() {
+ llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
+ kReadWritePrefix, kWriteOnlyPrefix,
+ kReadOnlyPrefix)
+ .str());
+ return rx;
+}
+
+void mlir::NVVM::countPlaceholderNumbers(
+ StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
+ llvm::SmallDenseSet<unsigned int> &seenW,
+ llvm::SmallDenseSet<unsigned int> &seenR,
+ llvm::SmallVectorImpl<unsigned int> &rwNums,
+ llvm::SmallVectorImpl<unsigned int> &wNums,
+ llvm::SmallVectorImpl<unsigned int> &rNums) {
+
+ llvm::Regex rx = getPredicateMappingRegex();
+ StringRef rest = ptxCode;
+
+ SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+ while (!rest.empty() && rx.match(rest, &m)) {
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+ // Insert it into the vector only the first time we see this number
+ if (m[1].equals_insensitive(kReadWritePrefix)) {
+ if (seenRW.insert(num).second)
+ rwNums.push_back(num);
+ } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
+ if (seenW.insert(num).second)
+ wNums.push_back(num);
+ } else {
+ if (seenR.insert(num).second)
+ rNums.push_back(num);
+ }
+
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+}
+
+/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
+/// compact `$K` indices:
+/// - All `rw*` first (sorted by N),
+/// - Then `w*`,
+/// - Then `r*`.
+/// If there is a predicate, it always comes at the end.
+/// Each number is assigned once; duplicates are ignored.
+///
+/// Example Input:
+/// "{
+/// reg .pred p;
+///    setp.ge.s32 p, {$r0}, {$r1};
+/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
+/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
+/// selp.s32 {$w0}, {$r0}, {$r1}, p;
+/// selp.s32 {$w1}, {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+/// reg .pred p;
+///    setp.ge.s32 p, $4, $5;
+/// selp.s32 $0, $4, $5, p;
+/// selp.s32 $1, $4, $5, p;
+/// selp.s32 $2, $4, $5, p;
+/// selp.s32 $3, $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
+ llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+ llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+ // Step 1. Count Register Placeholder numbers
+ countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
+
+ // Step 2. Sort the Register Placeholder numbers
+ llvm::sort(rwNums);
+ llvm::sort(wNums);
+ llvm::sort(rNums);
+
+ // Step 3. Create mapping from original to new IDs
+ llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+ unsigned nextId = 0;
+ for (unsigned n : rwNums)
+ rwMap[n] = nextId++;
+ for (unsigned n : wNums)
+ wMap[n] = nextId++;
+ for (unsigned n : rNums)
+ rMap[n] = nextId++;
+
+ // Step 4. Rewrite the PTX code with new IDs
+ std::string out;
+ out.reserve(ptxCode.size());
+ size_t prev = 0;
+ StringRef rest = ptxCode;
+ SmallVector<StringRef, 3> matches;
+ llvm::Regex rx = getPredicateMappingRegex();
+ while (!rest.empty() && rx.match(rest, &matches)) {
+ // Compute absolute match bounds in the original buffer.
+ size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
+ size_t absEnd = absStart + matches[0].size();
+
+ // Emit text before the match.
+ out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
+
+ // Emit compact $K
+ unsigned num = 0;
+ (void)matches[2].getAsInteger(10, num);
+ unsigned id = 0;
+ if (matches[1].equals_insensitive(kReadWritePrefix))
+ id = rwMap.lookup(num);
+ else if (matches[1].equals_insensitive(kWriteOnlyPrefix))
+ id = wMap.lookup(num);
+ else
+ id = rMap.lookup(num);
+
+ out.push_back('$');
+ out += std::to_string(id);
+
+ prev = absEnd;
+
+ const size_t advance =
+ (size_t)(matches[0].data() - rest.data()) + matches[0].size();
+ rest = rest.drop_front(advance);
+ }
+
+ // Step 5. Tail.
+ out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
+ return out;
}
LLVM::InlineAsmOp PtxBuilder::build() {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- auto resultTypes = interfaceOp->getResultTypes();
+ SmallVector<Type> resultTypes = packResultTypes(
+ interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
registerConstraints[registerConstraints.size() - 1] == ',')
registerConstraints.pop_back();
+ registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
std::string ptxInstruction = interfaceOp.getPtx();
+ if (!needsManualRegisterMapping)
+ ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
@@ -136,7 +456,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
rewriter, interfaceOp->getLoc(),
/*result types=*/resultTypes,
/*operands=*/ptxOperands,
- /*asm_string=*/llvm::StringRef(ptxInstruction),
+ /*asm_string=*/ptxInstruction,
/*constraints=*/registerConstraints.data(),
/*has_side_effects=*/interfaceOp.hasSideEffect(),
/*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -146,10 +466,89 @@ LLVM::InlineAsmOp PtxBuilder::build() {
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
- LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
- rewriter.replaceOp(interfaceOp, inlineAsmOp);
- } else {
+ LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
+
+ // Case 0: no result at all → just erase wrapper op.
+ if (!hasResult) {
rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ if (needsManualRegisterMapping) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ return;
+ }
+
+ // Case 1: Simple path, return single scalar
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ if (inlineAsmOp->getNumResults() > 0) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ } else {
+ // RW-only case with no declared results: forward the RW value.
+ SmallVector<Value> results;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite) {
+ results.push_back(v);
+ break;
+ }
+ rewriter.replaceOp(interfaceOp, results);
+ }
+ return;
+ }
+
+ const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+
+ // All multi-value paths produce a single struct result we need to unpack.
+ assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+ "expected struct return for multi-result inline asm");
+ Value structVal = inlineAsmOp.getResult(0);
+ SmallVector<Value> unpacked =
+ extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+ // Case 2: only declared results (no RW): replace the op with all unpacked.
+ if (!hasRW && interfaceOp->getResults().size() > 0) {
+ rewriter.replaceOp(interfaceOp, unpacked);
+ return;
+ }
+
+ // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+ if (hasRW && interfaceOp->getResults().size() == 0) {
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ // Case 4: mixed (RW + declared results).
+ {
+ // First rewrite RW operands in place.
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ // The remaining unpacked values correspond to the declared results.
+ SmallVector<Value> tail;
+ tail.reserve(unpacked.size() - idx);
+ for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+ tail.push_back(unpacked[i]);
+
+ rewriter.replaceOp(interfaceOp, tail);
}
}
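A compact end-to-end example of the two string rewrites used by the builder
(illustrative only; the PTX text and operand mix are hypothetical):

// One read-write i32, one write-only f32 result, one read-only i32:
//   collected constraints:  "+r,=f,r"
//   canonicalizeRegisterConstraints -> "=r,=f,r,0"   // "0" ties operand 0
//   PTX from the op:        "foo {$rw0}, {$w0}, {$r0};"
//   rewriteAsmPlaceholders  -> "foo $0, $1, $2;"     // rw first, then w, then r
// The inline asm then yields a struct {i32, f32}; buildAndReplaceOp unpacks
// it, replacing uses of the RW operand with element 0 and the op's declared
// result with element 1.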
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 1e02bfe..e268e8f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -51,6 +53,87 @@ void LLVMDialect::registerAttributes() {
}
//===----------------------------------------------------------------------===//
+// AddressSpaceAttr
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is an LLVM type that can be loaded or stored.
+static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!isLoadableType(type)) {
+ if (emitError)
+ emitError() << "type must be LLVM type with size, but got " << type;
+ return false;
+ }
+ if (ordering == ptr::AtomicOrdering::not_atomic)
+ return true;
+
+ // To check atomic validity we need a datalayout.
+ if (!dataLayout) {
+ if (emitError)
+ emitError() << "expected a valid data layout";
+ return false;
+ }
+ if (!isTypeCompatibleWithAtomicOp(type, *dataLayout)) {
+ if (emitError)
+ emitError() << "unsupported type " << type << " for atomic access";
+ return false;
+ }
+ return true;
+}
+
+bool AddressSpaceAttr::isValidLoad(
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+}
+
+bool AddressSpaceAttr::isValidStore(
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+}
+
+bool AddressSpaceAttr::isValidAtomicOp(
+ ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once `ptr.atomic_rmw` is implemented.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidAtomicXchg(
+ Type type, ptr::AtomicOrdering successOrdering,
+ ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidAddrSpaceCast(
+ Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once the `ptr.addrspace_cast` op is added to the
+ // dialect.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidPtrIntCast(
+ Type intLikeTy, Type ptrLikeTy,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once the int-cast ops are added to the `ptr`
+ // dialect.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
// AliasScopeAttr
//===----------------------------------------------------------------------===//
@@ -374,6 +457,43 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
getAttributeName());
}
+FailureOr<Attribute> TargetFeaturesAttr::query(DataLayoutEntryKey key) {
+ auto stringKey = dyn_cast<StringAttr>(key);
+ if (!stringKey)
+ return failure();
+
+ if (contains(stringKey))
+ return UnitAttr::get(getContext());
+
+ if (contains((std::string("+") + stringKey.strref()).str()))
+ return BoolAttr::get(getContext(), true);
+
+ if (contains((std::string("-") + stringKey.strref()).str()))
+ return BoolAttr::get(getContext(), false);
+
+ return failure();
+}
+
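Sample lookups against the query hook above (illustrative only; the feature
names are hypothetical):

// For a target-features attribute holding ["+sse", "-avx", "crc"]:
//   query("sse") -> BoolAttr true   (stored with a "+" prefix)
//   query("avx") -> BoolAttr false  (stored with a "-" prefix)
//   query("crc") -> UnitAttr        (stored verbatim)
//   query("fma") -> failure()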
+//===----------------------------------------------------------------------===//
+// TargetAttr
+//===----------------------------------------------------------------------===//
+
+FailureOr<::mlir::Attribute> TargetAttr::query(DataLayoutEntryKey key) {
+ if (auto stringAttrKey = dyn_cast<StringAttr>(key)) {
+ if (stringAttrKey.getValue() == "triple")
+ return getTriple();
+ if (stringAttrKey.getValue() == "chip")
+ return getChip();
+ if (stringAttrKey.getValue() == "features" && getFeatures())
+ return getFeatures();
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleFlagAttr
+//===----------------------------------------------------------------------===//
+
LogicalResult
ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LLVM::ModFlagBehavior flagBehavior, StringAttr key,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 422039f..ef27070 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
return success();
}
+static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
+ bool isExpandLoad,
+ uint64_t alignment = 1) {
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The pointer alignment defaults to 1.
+ if (alignment == 1) {
+ return nullptr;
+ }
+
+ auto emptyDictAttr = builder.getDictionaryAttr({});
+ auto alignmentAttr = builder.getI64IntegerAttr(alignment);
+ auto namedAttr =
+ builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
+ SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
+ auto alignDictAttr = builder.getDictionaryAttr(attrs);
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The align parameter attribute can be provided for [expandload]'s first
+ // argument. The align parameter attribute can be provided for
+ // [compressstore]'s second argument.
+ int pos = isExpandLoad ? 0 : 1;
+ return pos == 0 ? builder.getArrayAttr(
+ {alignDictAttr, emptyDictAttr, emptyDictAttr})
+ : builder.getArrayAttr(
+ {emptyDictAttr, alignDictAttr, emptyDictAttr});
+}
+
//===----------------------------------------------------------------------===//
// Operand bundle helpers.
//===----------------------------------------------------------------------===//
@@ -821,8 +853,8 @@ void LoadOp::getEffects(
/// Returns true if the given type is supported by atomic operations. All
/// integer, float, and pointer types with a power-of-two bitsize and a minimal
/// size of 8 bits are supported.
-static bool isTypeCompatibleWithAtomicOp(Type type,
- const DataLayout &dataLayout) {
+bool LLVM::isTypeCompatibleWithAtomicOp(Type type,
+ const DataLayout &dataLayout) {
if (!isa<IntegerType, LLVMPointerType>(type))
if (!isCompatibleFloatingPointType(type))
return false;
@@ -836,8 +868,9 @@ static bool isTypeCompatibleWithAtomicOp(Type type,
/// Verifies the attributes and the type of atomic memory access operations.
template <typename OpTy>
-LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
- ArrayRef<AtomicOrdering> unsupportedOrderings) {
+static LogicalResult
+verifyAtomicMemOp(OpTy memOp, Type valueType,
+ ArrayRef<AtomicOrdering> unsupportedOrderings) {
if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
DataLayout dataLayout = DataLayout::closest(memOp);
if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
@@ -1087,7 +1120,7 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
/// Verify that the parameter and return types of the variadic callee type match
/// the `callOp` argument and result types.
template <typename OpTy>
-LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
+static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
if (!varCalleeType)
return success();
@@ -2500,7 +2533,7 @@ LogicalResult GlobalOp::verifyRegions() {
// LLVM::GlobalCtorsOp
//===----------------------------------------------------------------------===//
-LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
+static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
if (data.empty())
return success();
@@ -4117,6 +4150,32 @@ LogicalResult LLVM::masked_scatter::verify() {
}
//===----------------------------------------------------------------------===//
+// masked_expandload (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
+ mlir::TypeRange resTys, Value ptr,
+ Value mask, Value passthru,
+ uint64_t align) {
+ ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
+ build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// masked_compressstore (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_compressstore::build(OpBuilder &builder,
+ OperationState &state, Value value,
+ Value ptr, Value mask, uint64_t align) {
+ ArrayAttr argAttrs =
+ getLLVMAlignParamForCompressExpand(builder, false, align);
+ build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
// InlineAsmOp
//===----------------------------------------------------------------------===//
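A sketch of what the new builders produce (illustrative only; values, types,
and the sample alignment are hypothetical):

//   LLVM::masked_expandload::create(b, loc, TypeRange{vecTy}, ptr, mask,
//                                   passthru, /*align=*/16);
// attaches the alignment as a parameter attribute on the pointer operand:
//   masked_expandload    : arg_attrs = [{llvm.align = 16 : i64}, {}, {}]
//   masked_compressstore : arg_attrs = [{}, {llvm.align = 16 : i64}, {}]
// With the default alignment of 1, no arg_attrs are emitted at all.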
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index e7d5dad..ef38027 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "sroa"
@@ -734,9 +735,8 @@ static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
return false;
})
.Default([&](Type type) {
- LLVM_DEBUG(llvm::dbgs()
- << "[sroa] Unsupported type for offset computations"
- << type << "\n");
+          LDBG() << "[sroa] Unsupported type for offset computations: "
+                 << type;
return true;
});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 78b4411..297640c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -24,7 +24,9 @@ using namespace mlir::LLVM;
/// prints it as usual.
static void dispatchPrint(AsmPrinter &printer, Type type) {
if (isCompatibleType(type) &&
- !llvm::isa<IntegerType, FloatType, VectorType>(type))
+ !(llvm::isa<IntegerType, FloatType, VectorType>(type) ||
+ (llvm::isa<PtrLikeTypeInterface>(type) &&
+ !llvm::isa<LLVMPointerType>(type))))
return mlir::LLVM::detail::printType(type, printer);
printer.printType(type);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index fee2d3e..2dd0132 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -13,6 +13,7 @@
#include "TypeDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -701,6 +702,17 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
// Utility functions.
//===----------------------------------------------------------------------===//
+/// Check whether type is a compatible ptr type. These are pointer-like types
+/// with no element type, no metadata, and using the LLVM AddressSpaceAttr
+/// memory space.
+static bool isCompatiblePtrType(Type type) {
+ auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
+ if (!ptrTy)
+ return false;
+ return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
+ isa<AddressSpaceAttr>(ptrTy.getMemorySpace());
+}
+
bool mlir::LLVM::isCompatibleOuterType(Type type) {
// clang-format off
if (llvm::isa<
@@ -734,7 +746,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
if (auto vecType = llvm::dyn_cast<VectorType>(type))
return vecType.getRank() == 1;
- return false;
+ return isCompatiblePtrType(type);
}
static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
@@ -784,6 +796,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMX86AMXType
>([](Type) { return true; })
// clang-format on
+ .Case<PtrLikeTypeInterface>(
+ [](Type type) { return isCompatiblePtrType(type); })
.Default([](Type) { return false; });
if (!result)
@@ -805,6 +819,18 @@ bool mlir::LLVM::isCompatibleType(Type type) {
return LLVMDialect::isCompatibleType(type);
}
+bool mlir::LLVM::isLoadableType(Type type) {
+ return /*LLVM_PrimitiveType*/ (
+ LLVM::isCompatibleOuterType(type) &&
+ !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
+ /*LLVM_OpaqueStruct*/
+ !(isa<LLVM::LLVMStructType>(type) &&
+ cast<LLVM::LLVMStructType>(type).isOpaque()) &&
+ /*LLVM_AnyTargetExt*/
+ !(isa<LLVM::LLVMTargetExtType>(type) &&
+ !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
+}
+
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type, LLVMPPCFP128Type>(type);
@@ -818,7 +844,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
return intType.isSignless();
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type, LLVMPointerType>(elementType);
+ Float80Type, Float128Type, LLVMPointerType>(elementType) ||
+ isCompatiblePtrType(elementType);
}
return false;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e0977f5..376e3c3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
@@ -50,7 +51,6 @@ using namespace NVVM;
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
@@ -82,8 +82,27 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
- if (getCoordinates().size() > 5)
- return emitError("Maximum 5 coordinates and dimension is supported.");
+ TMAStoreMode mode = getMode();
+ // We lower through inline-ptx when getPredicate() is true.
+ // a) Only TILE mode is supported
+ // b) Cache-hint is not supported
+ if (getPredicate()) {
+ if (mode != TMAStoreMode::TILE)
+ return emitError("Inline-ptx lowering supported only for Tile mode.");
+ if (getL2CacheHint())
+ return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
+ }
+
+ size_t dims = getCoordinates().size();
+ switch (mode) {
+ case TMAStoreMode::TILE:
+ return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+ case TMAStoreMode::IM2COL:
+ return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+ case TMAStoreMode::TILE_SCATTER4:
+ if (dims != 5)
+ return emitError("Scatter4 mode expects 5 coordinates");
+ }
return success();
}
@@ -98,17 +117,59 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+// This parameter verification is shared by the TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+ TMALoadMode mode, Location loc) {
+ if (tensorDims < 1 || tensorDims > 5)
+ return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+ auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+ size_t expectedIm2colOff) -> LogicalResult {
+ if (isIm2col && (tensorDims < 3))
+ return emitError(loc)
+ << "to use " << stringifyEnum(mode)
+ << " mode, the tensor has to be at least 3-dimensional";
+
+ if (numIm2colOff != expectedIm2colOff)
+ return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+ << " (provided " << numIm2colOff << ")";
+
+ return success();
+ };
+
+ switch (mode) {
+ case TMALoadMode::TILE:
+ return checkTMALoadParams(mode, false, 0);
+ case TMALoadMode::IM2COL:
+ return checkTMALoadParams(mode, true, tensorDims - 2);
+ case TMALoadMode::IM2COL_W:
+ case TMALoadMode::IM2COL_W_128:
+ return checkTMALoadParams(mode, true, 2);
+ case TMALoadMode::TILE_GATHER4:
+ return (tensorDims == 5)
+ ? checkTMALoadParams(mode, false, 0)
+ : emitError(loc, "Gather4 mode expects 5 coordinates");
+ }
+ return success();
+}
+
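A recap of the per-mode checks above (illustrative summary, not new behavior):

// With `dims` coordinates, the verifier requires:
//   TILE           -> 0 im2col offsets
//   IM2COL         -> dims - 2 offsets, and dims >= 3
//   IM2COL_W[_128] -> exactly 2 offsets, and dims >= 3
//   TILE_GATHER4   -> 0 offsets, and dims == 5
// and 1 <= dims <= 5 in every mode.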
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
- size_t numIm2ColOffsets = getIm2colOffsets().size();
- bool isIm2Col = numIm2ColOffsets > 0;
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
- numIm2ColOffsets, getLoc());
+ return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+ getMode(), getLoc());
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
- bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
- getLoc());
+ TMAStoreMode mode = getMode();
+ size_t dims = getCoordinates().size();
+ switch (mode) {
+ case TMAStoreMode::TILE:
+ return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+ case TMAStoreMode::IM2COL:
+ return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+ case TMAStoreMode::TILE_SCATTER4:
+ return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
+ }
+ return success();
}
LogicalResult ConvertFloatToTF32Op::verify() {
@@ -189,6 +250,26 @@ LogicalResult BulkStoreOp::verify() {
return success();
}
+LogicalResult PMEventOp::verify() {
+ auto eventId = getEventId();
+ auto maskedEventId = getMaskedEventId();
+ if (!maskedEventId && !eventId) {
+ return emitOpError() << "either `id` or `mask` must be set";
+ }
+
+ if (maskedEventId && eventId) {
+ return emitOpError() << "`id` and `mask` cannot be set at the same time";
+ }
+
+ if (eventId) {
+ if (eventId < 0 || eventId > 15) {
+ return emitOpError() << "`id` must be between 0 and 15";
+ }
+ }
+
+ return llvm::success();
+}
+
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
@@ -791,24 +872,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
}
LogicalResult NVVM::LdMatrixOp::verify() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
- if (addressSpace != NVVM::kSharedMemorySpace)
- return emitOpError("expected source pointer in memory space 3");
-
- if (getNum() != 1 && getNum() != 2 && getNum() != 4)
- return emitOpError("expected num attribute to be 1, 2 or 4");
+ uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
+ if (m == 8 && n == 8) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
+ "matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B16) {
+ return emitOpError("expected element type to be b16 for 8x8 matrix");
+ }
+ } else if (m == 8 && n == 16) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::row) {
+ return emitOpError("expected layout to be row for 8x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 8x16 matrix");
+ }
+ } else if (m == 16 && n == 16) {
+ if (num != 1 && num != 2) {
+ return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::col) {
+ return emitOpError("expected layout to be col for 16x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8 &&
+ getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 16x16 matrix");
+ }
+ } else {
+ return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
+ }
Type i32 = IntegerType::get(getContext(), 32);
- if (getNum() == 1 && getType() != i32)
+ uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
+ if (numElements == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
- if (getNum() == 2 || getNum() == 4) {
+ if (numElements == 2 || numElements == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
- getContext(), SmallVector<Type>(getNum(), i32));
+ getContext(), SmallVector<Type>(numElements, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
- << getNum() << " elements of type i32";
+ << numElements << " elements of type i32";
}
+
return success();
}
@@ -1069,7 +1184,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
return ptx;
}
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
@@ -1100,7 +1215,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
{makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
mlir::NVVM::PTXRegisterMod::Read});
}
+ return true; // Has manual mapping
}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1216,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() {
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
+ std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
- if (getUniform()) {
- if (getCacheLevel() != CacheLevel::L1)
- return emitOpError("unsupported cache level, the only supported uniform "
- "cache level is L1");
+ if (getTensormap() && cacheLevel)
+ return emitOpError("cannot specify both tensormap and cache level");
- if (addressSpace != MemSpace::kGenericMemorySpace)
+ if (getTensormap()) {
+ if (addressSpace != MemSpace::kGenericMemorySpace &&
+ addressSpace != MemSpace::kConstantMemorySpace) {
return emitOpError(
- "prefetch to uniform cache requires a generic pointer");
- }
+ "prefetch tensormap requires a generic or constant pointer");
+ }
- if (evictPriority) {
- if (getCacheLevel() != CacheLevel::L2)
+ if (evictPriority) {
return emitOpError(
- "cache eviction priority supported only for cache level L2");
-
- if (addressSpace != MemSpace::kGlobalMemorySpace)
- return emitOpError("cache eviction priority requires a global pointer");
+ "prefetch tensormap does not support eviction priority");
+ }
- if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
- *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+ if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
return emitOpError(
- "unsupported cache eviction priority, only evict_last and "
- "evict_normal are supported");
+ "in_param_space can only be specified for a generic pointer");
+ }
+
+ } else if (cacheLevel) {
+ if (addressSpace != MemSpace::kGenericMemorySpace &&
+ addressSpace != MemSpace::kGlobalMemorySpace &&
+ addressSpace != MemSpace::kLocalMemorySpace) {
+ return emitOpError("prefetch to cache level requires a generic, global, "
+ "or local pointer");
+ }
+
+ if (getUniform()) {
+ if (*cacheLevel != CacheLevel::L1) {
+ return emitOpError(
+ "unsupported cache level, the only supported uniform "
+ "cache level is L1");
+ }
+
+ if (addressSpace != MemSpace::kGenericMemorySpace) {
+ return emitOpError(
+ "prefetch to uniform cache requires a generic pointer");
+ }
+ }
+
+ if (evictPriority) {
+ if (*cacheLevel != CacheLevel::L2)
+ return emitOpError(
+ "cache eviction priority supported only for cache level L2");
+
+ if (addressSpace != MemSpace::kGlobalMemorySpace)
+ return emitOpError("cache eviction priority requires a global pointer");
+
+ if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
+ *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+ return emitOpError(
+ "unsupported cache eviction priority, only evict_last and "
+ "evict_normal are supported");
+ }
+
+ if (getPredicate())
+ return emitOpError("predicate supported only on prefetch tensormap");
+
+ } else {
+ return emitOpError(
+ "requires specification of either cache level or tensormap");
}
return success();
@@ -1379,28 +1536,102 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
+}
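Editorial sketch, not part of the patch: the selection above is table-driven, with one row per TMA load mode, one column per tensor rank (column 0 unused), `not_intrinsic` as the hole marker, and a static_assert tying the mode enum's maximum value to the row count. The same pattern, stripped to a self-contained toy with made-up mode names and numeric stand-ins for the intrinsic IDs:

#include <cassert>
#include <cstddef>
#include <iterator>

enum class LoadMode { Tile = 0, Im2Col = 1 };
constexpr std::size_t kMaxLoadMode = 1; // stand-in for getMaxEnumValForTMALoadMode()

unsigned pickPrefetchIntrinsic(LoadMode mode, std::size_t rank) {
  constexpr unsigned NI = 0; // stand-in for llvm::Intrinsic::not_intrinsic
  static constexpr unsigned IDTable[][6] = {
      {NI, 11, 12, 13, 14, 15}, // tile_1d .. tile_5d (fake IDs)
      {NI, NI, NI, 23, 24, 25}, // im2col_3d .. im2col_5d (fake IDs)
  };
  static_assert(kMaxLoadMode == std::size(IDTable) - 1,
                "modes must match the number of rows in IDTable");
  unsigned id = IDTable[static_cast<std::size_t>(mode)][rank];
  assert(id != NI && "unsupported (mode, rank) combination");
  return id;
}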
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
+ {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
+
+ static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
+ "TMAStoreModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable(
+ "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
+
+ return {id, std::move(args)};
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
@@ -1566,7 +1797,7 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
.getAddressSpace();
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
- bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+ bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
llvm::Intrinsic::ID id;
if (isShared) {
@@ -1588,7 +1819,7 @@ llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt,
llvm::SmallVector<llvm::Value *> &args) {
auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
- auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
+ auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
: llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
@@ -1616,7 +1847,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
.getAddressSpace();
bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
- bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+ bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
llvm::Intrinsic::ID id =
is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
@@ -1648,7 +1879,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
- bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+ bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
auto srcFmt = curOp.getSrcFormat();
auto mc = curOp.getMulticast();
@@ -1774,26 +2005,47 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
return {ids[type], args};
}
-llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
+static llvm::Value *getParamCastedAddr(llvm::Value *addr,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateAddrSpaceCast(
+ addr,
+ llvm::PointerType::get(builder.getContext(),
+ llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
+}
+
+NVVM::IDArgPair
+PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;
- NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
+ std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
std::optional<NVVM::CacheEvictionPriority> evictPriority =
op.getEvictPriority();
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
.getAddressSpace();
- if (op.getUniform() && cacheLevel == CacheLevel::L1)
- return llvm::Intrinsic::nvvm_prefetchu_L1;
+ llvm::SmallVector<llvm::Value *> args;
+ llvm::Value *addr = mt.lookupValue(op.getAddr());
+ args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
+ : addr);
+
+ if (op.getTensormap())
+ return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
+
+ assert(cacheLevel && "expected cache level for non-tensormap prefetch");
+
+ if (op.getUniform() && *cacheLevel == CacheLevel::L1)
+ return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
- if (evictPriority && cacheLevel == CacheLevel::L2) {
+ if (evictPriority && *cacheLevel == CacheLevel::L2) {
switch (*evictPriority) {
case NVVM::CacheEvictionPriority::EvictLast:
- return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
+ return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
case NVVM::CacheEvictionPriority::EvictNormal:
- return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
+ return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
default:
llvm_unreachable("Invalid cache eviction priority");
}
@@ -1801,21 +2053,41 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
switch (addressSpace) {
case MemSpace::kGenericMemorySpace:
- return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
- : llvm::Intrinsic::nvvm_prefetch_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
+ : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
case MemSpace::kGlobalMemorySpace:
- return cacheLevel == CacheLevel::L1
- ? llvm::Intrinsic::nvvm_prefetch_global_L1
- : llvm::Intrinsic::nvvm_prefetch_global_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
+ : NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
case MemSpace::kLocalMemorySpace:
- return cacheLevel == CacheLevel::L1
- ? llvm::Intrinsic::nvvm_prefetch_local_L1
- : llvm::Intrinsic::nvvm_prefetch_local_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
+ : NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
default:
llvm_unreachable("Invalid pointer address space");
}
}
+bool NVVM::InlinePtxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ for (auto arg : getReadWriteArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
+ for (auto arg : getResults())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto arg : getReadOnlyArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
+ if (getPredicate())
+ asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
+ return false; // No manual mapping needed
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -1854,19 +2126,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
attrName == NVVMDialect::getReqntidAttrName() ||
attrName == NVVMDialect::getClusterDimAttrName()) {
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
- if (!values || values.empty() || values.size() > 3)
+ if (!values || values.empty() || values.size() > 3) {
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
+ }
}
// If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
// attribute
if (attrName == NVVMDialect::getMinctasmAttrName() ||
attrName == NVVMDialect::getMaxnregAttrName() ||
attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
- if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
+ if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
return op->emitError()
<< "'" << attrName << "' attribute must be integer constant";
+ }
+ }
+ // blocksareclusters must be used along with reqntid and cluster_dim
+ if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
+ if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
+ !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
+ return op->emitError()
+ << "'" << attrName << "' attribute must be used along with "
+ << "'" << NVVMDialect::getReqntidAttrName() << "' and "
+ << "'" << NVVMDialect::getClusterDimAttrName() << "'";
+ }
}
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index 8317b67..23b4130 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
using namespace LLVM;
@@ -63,9 +63,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
}
if (maxNumRewrites && numRewrites >= *maxNumRewrites) {
- LLVM_DEBUG(llvm::dbgs()
- << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
- << maxNumRewrites << ")\n");
+ LDBG() << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
+ << maxNumRewrites << ")";
// Skip rewriting the rest.
result.append(inputs.begin(), inputs.end());
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index b951df8..4ea2ac9 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -129,7 +129,6 @@ handleInlinedAllocas(Operation *call,
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPoint(allocaOp);
LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(),
- arraySize.getValue().getLimitedValue(),
allocaOp.getResult());
}
allocaOp->moveAfter(newConstant);
@@ -147,7 +146,6 @@ handleInlinedAllocas(Operation *call,
for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
if (shouldInsertLifetime)
LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(),
- arraySize.getValue().getLimitedValue(),
allocaOp.getResult());
}
}
@@ -237,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) {
WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
// Attempt to advance to the source of the underlying view-like operation.
// Examples of view-like operations include GEPOp and AddrSpaceCastOp.
- if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>())
- return WalkContinuation::advanceTo(viewOp.getViewSource());
+ if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) {
+ if (val == viewOp.getViewDest())
+ return WalkContinuation::advanceTo(viewOp.getViewSource());
+ }
// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 34c63d3..578931e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
ArrayRef<AffineMap> indexingMaps) {
// Initialize indexingMaps attribute, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
- indexingMapsAttrVal = llvm::map_to_vector(
- MatmulOp::getDefaultIndexingMaps(b.getContext()),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
@@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
+ // 2) The payload op must have the same number of operands as the number of
+ // block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
- return nullptr;
+ return false;
+
+  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+  //    argument must be the first operand of the payload op; otherwise, the
+  //    operands must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
- return nullptr;
+ return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+
+ // 4) The `yield` operand must be the result of the payload op.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1621,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
// MatMulOp
//===----------------------------------------------------------------------===//
+static FailureOr<SmallVector<SmallVector<int64_t>>>
+getAffineResultPositions(ArrayAttr maps) {
+ SmallVector<SmallVector<int64_t>> positions;
+ for (auto map : maps) {
+ AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
+ if (!attr)
+ return failure();
+ SmallVector<int64_t> pos;
+ for (auto result : attr.getAffineMap().getResults()) {
+ auto dim = dyn_cast<AffineDimExpr>(result);
+ if (!dim)
+ return failure();
+ pos.push_back(dim.getPosition());
+ }
+ positions.push_back(pos);
+ }
+ return positions;
+}
+
 /// Returns a list of AffineMap with the typical matmul indexing characteristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
AffineExpr d0, d1, d2;
@@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
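Editorial note, not part of the patch: the `{0, 2} / {2, 1} / {0, 1}` triples compared above are simply the dimension positions of the canonical matmul maps `(d0, d2)`, `(d2, d1)`, `(d0, d1)` that `getDefaultIndexingMaps` builds. A minimal sketch making that correspondence explicit (the helper name `canonicalMatmulMaps` is invented for illustration):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// The result dims of these maps are exactly the positions that
// MatmulOp::isDefaultIndexingMaps extracts and compares.
static void canonicalMatmulMaps(MLIRContext &ctx) {
  AffineExpr d0, d1, d2;
  bindDims(&ctx, d0, d1, d2);
  AffineMap mapA = AffineMap::get(3, 0, {d0, d2}, &ctx); // positions {0, 2}
  AffineMap mapB = AffineMap::get(3, 0, {d2, d1}, &ctx); // positions {2, 1}
  AffineMap mapC = AffineMap::get(3, 0, {d0, d1}, &ctx); // positions {0, 1}
  (void)mapA;
  (void)mapB;
  (void)mapC;
}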
+
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
@@ -3836,7 +3880,7 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
-FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
if (parser.parseOptionalKeyword("indexing_maps"))
return ArrayAttr{
nullptr}; // Success in case indexing_maps was not provided.
@@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+SmallVector<AffineMap>
+MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
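Editorial note, not part of the patch: with this `classof`, `MatmulTransposeAOp` is no longer a structurally distinct op; it acts as a view over any `linalg.matmul` whose `indexing_maps` match the A-transposed defaults, so `isa`/`dyn_cast` keep working, and the same pattern repeats below for the B-transposed and batch variants. A minimal, hypothetical use (the helper name is invented):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/Support/Casting.h"

// True iff `op` is a linalg.matmul carrying the A-transposed indexing maps;
// the cast goes through the classof hook defined above.
static bool isATransposedMatmul(mlir::Operation *op) {
  return llvm::isa_and_nonnull<mlir::linalg::MatmulTransposeAOp>(op);
}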
+
+SmallVector<AffineMap>
+MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{1, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//
@@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{
utils::IteratorType::parallel, utils::IteratorType::parallel,
@@ -5042,7 +5474,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
-bool areTilesAndTiledDimsAllConstant(OpTy op) {
+static bool areTilesAndTiledDimsAllConstant(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
@@ -5345,11 +5777,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getSourceType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
@@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{1, 2};
+}
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
std::string BatchReduceMatmulOp::getLibraryCallName() {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bdfc8d0..f0c1f44 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -68,12 +70,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
- struct TrivialPatternRewriter : public PatternRewriter {
- public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
- };
- TrivialPatternRewriter rewriter(operation->getContext());
+ PatternRewriter rewriter(operation->getContext());
rewriter.setInsertionPoint(operation);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
@@ -1985,14 +1982,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Convert the padding values to attributes.
SmallVector<Attribute> paddingValues;
- for (auto const &it :
+ for (auto const &[untypedAttr, elementOrTensorType] :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
+ auto attr = dyn_cast<TypedAttr>(untypedAttr);
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
- Type elementType = getElementTypeOrSelf(std::get<1>(it));
+ Type elementType = getElementTypeOrSelf(elementOrTensorType);
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2002,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
- << elementType << ", got " << std::get<0>(it);
+ << elementType << ", got " << untypedAttr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
@@ -2235,8 +2237,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
auto attr = dyn_cast<TypedAttr>(untypedAttr);
Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
// Try to parse string attributes to obtain an attribute of element type.
@@ -3783,8 +3790,15 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+ bool foldTypeExtensionsIntoContract, bool vectorizePadding,
+ bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
+ if (foldTypeExtensionsIntoContract) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::
+ getFoldTypeExtensionsIntoContractAttrName(result.name),
+ builder.getUnitAttr());
+ }
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3875,6 +3889,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
+ if (getFoldTypeExtensionsIntoContract())
+ vector::populateFoldArithExtensionPatterns(patterns);
+
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73..6912da3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -55,8 +55,8 @@ static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
// Skip the batch dimension if present.
// Offset all dimensions accordingly.
SmallVector<int64_t, 3> offsetDims(dims);
- for (size_t i = 0; i < offsetDims.size(); i++)
- offsetDims[i] += batchDimsOffset;
+ for (int64_t &offsetDim : offsetDims)
+ offsetDim += batchDimsOffset;
auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
OpBuilder builder(tileOp);
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
- patterns.getContext(), controlFn);
+ BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
+ controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 70f846e..fb39e186 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
- NamedOpConversions.cpp
+ SimplifyDepthwiseConv.cpp
+ NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index d1eb270..108abe8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}
-// Delinearizes the given composite `index` by the basis specified in `factors`.
-static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
- ArrayRef<int64_t> factors) {
- assert(!factors.empty() && "empty factor list");
- SmallVector<Value> basis;
- for (int64_t f : factors)
- basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
- FailureOr<SmallVector<Value>> multiIndex =
- affine::delinearizeIndex(b, loc, index, basis);
- assert(!failed(multiIndex) && "Failed to linearize img2col index");
- return *multiIndex;
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+ bool useSymbols = true) {
+ AffineExpr oExpr, fExpr;
+ if (useSymbols)
+ bindSymbols(b.getContext(), oExpr, fExpr);
+ else
+ bindDims(b.getContext(), oExpr, fExpr);
+ return AffineExpr(stride * oExpr + fExpr);
}
-// Given indices corresponding to iterators in the output (oIndex) and filter
-// (fIndex) for a convolution, compute the convolved index for the
-// input as `oIndex * stride + fIndex`.
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
- Value fIndex, int64_t stride) {
- AffineExpr oExpr, fExpr;
- bindSymbols(b.getContext(), oExpr, fExpr);
- AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
- return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the corresponding indices of the output and filter matrices
+struct Im2ColToOperandsExprs {
+ AffineExpr fhIndex;
+ AffineExpr fwIndex;
+ AffineExpr icIndex;
+ AffineExpr ohIndex;
+ AffineExpr owIndex;
+};
+
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the input matrix indices
+struct Im2ColToInputDimsExprs {
+ AffineExpr bIndex;
+ AffineExpr hIndex;
+ AffineExpr wIndex;
+ AffineExpr cIndex;
+};
+
+/// Construct the affine expressions that map the indices of the im2col matrix
+/// to the corresponding input tensor indices for a 2D convolution with the
+/// provided strides.
+///
+/// @param exprs Affine expressions for output and filter indices.
+/// @param strides [height, width] stride values for the convolution.
+/// @param rewriter Pattern rewriter.
+/// @return Affine expressions mapping im2col matrix indices to input
+/// offsets.
+static Im2ColToInputDimsExprs
+getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
+ ArrayRef<int64_t> strides, RewriterBase &rewriter) {
+ // maps the iteration space of the im2col matrix to (output_y, filter_y)
+ auto hIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
+ // maps the iteration space of the im2col matrix to (output_x, filter_x)
+ auto wIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
+ // Compute the input indexing map, to map the indices of the im2col matrix to
+ // the original input offsets. Each element of the im2col matrix corresponds
+ // to a pair of (out_element, filter_element). First, we build the expressions
+ // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
+ // then we compose them with the maps that map the im2col matrix elements to
+ // the (out_element, filter_element) pairs.
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+ auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
+                                     /*useSymbols=*/false);
+ hIndexExpr = hIndexExpr.compose(hIndicesMap);
+ auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
+                                     /*useSymbols=*/false);
+ wIndexExpr = wIndexExpr.compose(wIndicesMap);
+ auto cIndexExpr = exprs.icIndex;
+ return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}
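Editorial worked example, not part of the patch: to make the delinearization concrete for the NHWC rewrite that follows, assume hypothetical sizes OW = 4, FW = 3, IC = 2 and strides (2, 2). Then m = oh*OW + ow delinearizes with basis {OW, 1}, k = fh*FW*IC + fw*IC + ic delinearizes with basis {FW*IC, IC, 1}, and the composed map reads input[n, 2*oh + fh, 2*ow + fw, ic]. The same arithmetic in plain C++:

#include <cstdio>

int main() {
  const int OW = 4, FW = 3, IC = 2, SH = 2, SW = 2;
  const int m = 7, k = 11;                  // arbitrary im2col coordinates
  int oh = m / OW, ow = m % OW;             // delinearize m with basis {OW, 1}
  int fh = k / (FW * IC), rem = k % (FW * IC);
  int fw = rem / IC, ic = rem % IC;         // delinearize k with {FW*IC, IC, 1}
  // Prints "input[h=3, w=8, c=1]" for (m, k) = (7, 11).
  std::printf("input[h=%d, w=%d, c=%d]\n", SH * oh + fh, SW * ow + fw, ic);
  return 0;
}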
FailureOr<std::pair<Operation *, Operation *>>
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
+
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
- SmallVector<AffineMap, 4> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ // Recover the original iteration indices from the problem/input sizes:
+ // given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
+ ArrayRef<int64_t>{fh * fw, fw, 1});
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.icIndex = kIndicesExprs[0];
+ i2cToOperExprs.fhIndex = kIndicesExprs[1];
+ i2cToOperExprs.fwIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
+ inExprs.hIndex, inExprs.wIndex}},
+ rewriter.getContext())[0];
+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+ SmallVector<AffineMap> img2colIndexingMaps = {
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
- auto icIndex = kIndices[0];
- auto fhIndex = kIndices[1];
- auto fwIndex = kIndices[2];
-
- SmallVector<Value> nIndices = unrollIndex(
- nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = nIndices[0];
- auto owIndex = nIndices[1];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
- SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Value reshapedOutput = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedOutputType, output, outputReassocIndices);
+ // Shape of the Toeplitz matrix produced by Im2col.
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
inputType.getElementType());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because we didn't transpose the filters we don't actually have a batched
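Both hunks above replace the per-element `linalg.index` + `unrollIndex` + `tensor.extract` body with an affine indexing map on a new input operand, so the im2col region degenerates to a plain `linalg.yield`. As a rough, self-contained sketch of the kind of map involved for the NHWC case (the in-tree `delinearize` and `getIm2ColInputExpressions` helpers are not shown in this hunk, so the floordiv/mod expressions are spelled out directly and the strides are passed as plain integers):

```
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Input-side indexing map for the NHWC im2col generic:
//   im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
// with d0 = batch, d1 = m (flattened oh*ow), d2 = k (flattened fh*fw*ic).
static AffineMap buildNhwcIm2ColInputMap(MLIRContext *ctx, int64_t fw,
                                         int64_t ic, int64_t ow, int64_t sh,
                                         int64_t sw) {
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineExpr d2 = getAffineDimExpr(2, ctx);
  // Delinearize d1 over the basis {ow, 1} and d2 over {fw * ic, ic, 1}.
  AffineExpr ohE = d1.floorDiv(ow), owE = d1 % ow;
  AffineExpr fhE = d2.floorDiv(fw * ic), fwE = d2.floorDiv(ic) % fw,
             icE = d2 % ic;
  return AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0,
                        {d0, ohE * sh + fhE, owE * sw + fwE, icE}, ctx);
}
```

Pairing such a map with the identity map on the im2col init (as `img2colIndexingMaps` does above) is what lets the body shrink to `linalg.yield args[0]`.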
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4..2ff7f46 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -75,7 +75,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
// layout for best compatibility.
Value toBuffer = bufferization::ToBufferOp::create(
b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
- tensorSource, /*readOnly=*/true);
+ tensorSource, /*read_only=*/true);
memref::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
@@ -84,7 +84,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
// layout for best compatibility.
Value toBuffer = bufferization::ToBufferOp::create(
b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
- tensorSource, /*readOnly=*/true);
+ tensorSource, /*read_only=*/true);
linalg::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c176..40085a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ private:
ControlPropagationFn controlFn;
};
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+ OpFoldResult offset;
+ OpFoldResult sliceSize;
+ OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
+ OpOperand *sliceOperand = nullptr;
+ for (auto operand : genericOp.getDpsInputOperands()) {
+ auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ continue;
+ sliceOperand = operand;
+ break;
+ }
+ if (!sliceOperand) {
+ return failure();
+ }
+ return sliceOperand;
+}
+
+// Return a map of the dims that have partial slices on them so that other
+// operands can use this information. Bail out (return failure) if any such
+// dim is used in a non-AffineDimExpr result of another operand.
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+ llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+ SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+ genericOp.getContext(), producerSliceOp.getSourceType().getShape());
+
+ for (auto [idx, expr] : llvm::enumerate(
+ genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+ // If we have a full slice in a dimension, then we don't need to add it to
+ // the partial slice map.
+ if (isConstantIntValue(offsets[idx], 0) &&
+ isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+ continue;
+ }
+ // We only support partial slices of AffineDimExprs, so bail out if that's
+ // not the case.
+ if (!isa<AffineDimExpr>(expr)) {
+ return failure();
+ }
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+ partialSliceDimMap[dimPos] = sliceDimInfo;
+ }
+ // Next, check whether the dims with partial slice info are used in
+ // non-AffineDimExpr results of other operands; if they are, bail out.
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ if (operand == *sliceOperand) {
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+ if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+ if (isa<AffineDimExpr>(expr)) {
+ return false;
+ }
+ WalkResult status = expr.walk([&](AffineExpr expr) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (partialSliceDimMap.contains(dimExpr.getPosition())) {
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (status.wasInterrupted()) {
+ return true;
+ }
+ return false;
+ })) {
+ return failure();
+ }
+ }
+ return partialSliceDimMap;
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp,
+ ControlPropagationFn controlFn) {
+ if (genericOp.getNumResults() != 1)
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation through multi-result generic is unsupported.");
+ if (hasGatherSemantics(genericOp))
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation through generic with gather semantics is unsupported.");
+ // Collect the sliced operand, if present.
+ auto maybeSliceOperand = getSliceOperand(genericOp);
+ if (failed(maybeSliceOperand))
+ return failure();
+ OpOperand *sliceOperand = *maybeSliceOperand;
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
+
+ if (!controlFn(sliceOperand))
+ return failure();
+
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+ if (producerSliceOp.getSource().getType().getRank() !=
+ producerSliceOp.getResult().getType().getRank()) {
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation of rank-reducing extract slice is unsupported.");
+ }
+
+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+ if (!areAllConstantIntValue(strides, 1))
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation of strided extract slice is unsupported.");
+
+ // Check whether we can support propagating this extract_slice through the
+ // generic op and, if so, collect the partially sliced dimensions.
+
+ auto maybePartialSliceDimMap =
+ getPartialSliceDimInfo(genericOp, sliceOperand);
+
+ if (failed(maybePartialSliceDimMap)) {
+ return failure();
+ }
+
+ auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+ bool hasPartialReductionDimSlice =
+ llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+ int64_t sliceDim = slice.first;
+ return iterators[sliceDim] == utils::IteratorType::reduction;
+ });
+
+ // Helper to compute high padding amounts as (outputSize - offset - sliceSize).
+ Location loc = genericOp->getLoc();
+ AffineExpr dim0, dim1;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+ {v1, v2});
+ };
+
+ MLIRContext *ctx = genericOp.getContext();
+ SmallVector<Value> paddedInputs;
+ for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ if (idx == OperandIndex && !hasPartialReductionDimSlice) {
+ paddedInputs.push_back(producerSliceOp.getSource());
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+ SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
+ }
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ operandLowPads[idx] = sliceDimInfo.offset;
+ operandHighPads[idx] =
+ sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ }
+ auto paddingValue = ub::PoisonOp::create(
+ rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+ auto paddedOperand = tensor::PadOp::create(
+ rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+ paddingValue, /*nofold=*/false);
+ paddedInputs.push_back(paddedOperand);
+ }
+ AffineMap outputIndexingMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+ auto outputShapeType =
+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+ outputShapeType.getShape(),
+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+ SmallVector<OpFoldResult> newSizes = OutputShape;
+ SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 1));
+ for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
+ }
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ outputLowPads[idx] = sliceDimInfo.offset;
+ outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ OutputShape[idx] = sliceDimInfo.outputSize;
+ newSizes[idx] = sliceDimInfo.sliceSize;
+ }
+ Value newPadOutput;
+ auto outputElType =
+ getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+ if (isGenericOutsNotUsed(genericOp)) {
+ newPadOutput =
+ tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+ } else {
+ auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+ newPadOutput = tensor::PadOp::create(
+ rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+ outputHighPads, paddingValue, /*nofold=*/false);
+ }
+
+ auto newGenericOp = linalg::GenericOp::create(
+ rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+ genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+ newGenericOp.getRegion().begin());
+
+ auto extractOp = tensor::ExtractSliceOp::create(
+ rewriter, loc,
+ newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+ outputLowPads, newSizes, newStrides);
+ Value extractRes = extractOp.getResult();
+
+ return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+ : public OpRewritePattern<GenericOp> {
+public:
+ PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ auto genericAndRepl =
+ pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+ if (failed(genericAndRepl))
+ return failure();
+ rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+ return success();
+ }
+
+private:
+ ControlPropagationFn controlFn;
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation) {
+ patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+ patterns.getContext(), controlPackUnPackPropagation);
+}
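A minimal sketch of driving the new entry point (assuming some payload op `root`; the control callback here accepts every candidate operand, but a real client can veto individual `OpOperand`s just like the pack/unpack propagation control):

```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sinks tensor.extract_slice producers past linalg.generic consumers using
// the pattern registered by populateExtractSliceSinkingPatterns above.
static LogicalResult sinkExtractSlices(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateExtractSliceSinkingPatterns(
      patterns, /*control=*/[](OpOperand *) { return true; });
  return applyPatternsGreedily(root, std::move(patterns));
}
```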
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index bf66ed0..22690da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
auto newResultType = RankedTensorType::get(
newResultShape, padOp.getResultType().getElementType());
- auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
- newHighPad, paddingVal, padOp.getNofold());
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
+ newLowPad, newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
if (options.rankReductionStrategy ==
@@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c523153..baf4083 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,26 @@ namespace mlir {
using namespace mlir;
+static inline bool isScalarLike(Type t) {
+ return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ return !(isa<RankedTensorType>(t) || isScalarLike(t));
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
- auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
- SmallVector<AffineMap, 3> indexingMaps(
- op->getNumResults() + op->getNumOperands(),
- rewriter.getMultiDimIdentityMap(rank));
- SmallVector<utils::IteratorType, 6> iteratorTypes(
+ auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
+ auto rank = resTy.getRank();
+
+ // Maps: identity for tensors (rank > 0), scalar map for scalars.
+ AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
+ /*results=*/{}, rewriter.getContext());
+ AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
+
+ // Match phase.
+ SmallVector<bool> isScalarOperand;
+ isScalarOperand.reserve(op->getNumOperands());
+ for (Type ty : op->getOperandTypes()) {
+ if (isScalarLike(ty))
+ isScalarOperand.push_back(true);
+ else if (auto rt = dyn_cast<RankedTensorType>(ty))
+ isScalarOperand.push_back(false);
+ else
+ return rewriter.notifyMatchFailure(
+ op,
+ "unsupported operand type (expected scalar-like or ranked tensor)");
+ }
+
+ // Create indexing maps.
+ SmallVector<AffineMap> indexingMaps;
+ indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+ for (bool isScalar : isScalarOperand)
+ indexingMaps.push_back(isScalar ? scalarMap : idMap);
+
+ indexingMaps.append(op->getNumResults(), idMap);
+
+ SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
- auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
+ SmallVector<Value> outputs =
+ getOrCreateOperandsMatchingResultTypes(rewriter, op);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/op->getOperands(),
@@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
- auto resultTypes = llvm::to_vector<6>(
+ SmallVector<Type> resultEltTys = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return cast<TensorType>(type).getElementType();
}));
- auto *scalarOp =
+ Operation *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
regionArgs.take_front(op->getNumOperands()),
- resultTypes, op->getAttrs());
+ resultEltTys, op->getAttrs());
linalg::YieldOp::create(builder, loc, scalarOp->getResults());
});
return success();
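For reference, a small sketch of the indexing maps the relaxed pattern now builds when one operand is scalar-like (hypothetical operands such as `arith.mulf` of a `tensor<4x8xf32>` and an `f32`): the scalar gets a map with `rank` dims and no results, so the same value is broadcast to every iteration.

```
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Maps for a rank-2 (tensor, scalar) -> tensor elementwise op:
//   tensor operand / result: (d0, d1) -> (d0, d1)
//   scalar operand:          (d0, d1) -> ()
static SmallVector<AffineMap> mixedOperandMaps(Builder &b, unsigned rank) {
  AffineMap idMap = b.getMultiDimIdentityMap(rank);
  AffineMap scalarMap =
      AffineMap::get(rank, /*symbolCount=*/0, /*results=*/{}, b.getContext());
  return {/*tensor=*/idMap, /*scalar=*/scalarMap, /*result=*/idMap};
}
```

This mirrors the `scalarMap`/`idMap` pair constructed in the pattern above; only the example operand types are new.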
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index fd530f2..9436f1c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
auto clonedForOp = scf::ForOp::create(
rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()),
bvm.lookupOrDefault(forOp.getUpperBound()),
- bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
+ bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Map the induction var, region args and results to the `clonedForOp`.
bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 58986a6..36434cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
scf::ForOp newLoop = scf::ForOp::create(
rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loop.getUnsignedCmp());
// Generate the new yield with the replaced operand.
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
@@ -165,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
Value source = transferRead.getBase();
// Skip view-like Ops and retrieve the actual source Operation
- while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
- source = srcOp.getViewSource();
+ while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
+ if (viewLike.getViewDest() != source) {
+ break;
+ }
+ source = viewLike.getViewSource();
+ }
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
@@ -177,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
if (!processed.insert(user).second)
continue;
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
- users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
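The Hoisting.cpp hunks above tighten the view-like walk in two ways: a defining op is only looked through when the traced value really is its view destination, and aliasing users are gathered from that destination value rather than from all of the op's results. Restated as a standalone helper (same logic as the loop above; `getViewDest` is the interface accessor the patch relies on):

```
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

using namespace mlir;

// Follow view-like producers back to the underlying source, stopping when the
// current value is not the op's view destination (i.e. the op only consumes
// it rather than producing it as a view).
static Value walkToUnderlyingSource(Value source) {
  while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
    if (viewLike.getViewDest() != source)
      break;
    source = viewLike.getViewSource();
  }
  return source;
}
```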
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
new file mode 100644
index 0000000..f261ccb
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -0,0 +1,62 @@
+//===- MorphOps.cpp - Conversion between named, category and generic ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements conversions between linalg ops:
+// named <--> category (elementwise, contraction, ..) <--> generic.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMORPHOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-morphism"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct LinalgMorphOpsPass
+ : public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {
+
+ using impl::LinalgMorphOpsPassBase<
+ LinalgMorphOpsPass>::LinalgMorphOpsPassBase;
+
+ void runOnOperation() override;
+};
+
+void LinalgMorphOpsPass::runOnOperation() {
+
+ RewritePatternSet patterns(&getContext());
+
+ // Lowering paths (named -> category -> generic)
+ if (namedToCategory) {
+ populateLinalgNamedToElementwisePatterns(patterns);
+ }
+ if (namedToGeneric || categoryToGeneric) {
+ populateLinalgNamedOpsGeneralizationPatterns(patterns);
+ }
+
+ // Lifting paths (named <- category <- generic)
+ if (genericToNamed) {
+ populateLinalgGenericOpsSpecializationPatterns(patterns);
+ }
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000..00a076b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,98 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting of linalg named ops that are essentially
+// elementwise, e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimizations on `linalg.elementwise`, such as folding transposes and
+// broadcasts.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+ElementwiseKind getKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+ .Case([](SelectOp) { return ElementwiseKind::select; })
+ .Case([](AddOp) { return ElementwiseKind::add; })
+ .Case([](SubOp) { return ElementwiseKind::sub; })
+ .Case([](MulOp) { return ElementwiseKind::mul; })
+ .Case([](DivOp) { return ElementwiseKind::div; })
+ .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+ .Case([](PowFOp) { return ElementwiseKind::powf; })
+ .Case([](ExpOp) { return ElementwiseKind::exp; })
+ .Case([](LogOp) { return ElementwiseKind::log; })
+ .Case([](AbsOp) { return ElementwiseKind::abs; })
+ .Case([](CeilOp) { return ElementwiseKind::ceil; })
+ .Case([](FloorOp) { return ElementwiseKind::floor; })
+ .Case([](NegFOp) { return ElementwiseKind::negf; })
+ .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+ .Case([](RoundOp) { return ElementwiseKind::round; })
+ .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+ .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+ .Case([](SquareOp) { return ElementwiseKind::square; })
+ .Case([](TanhOp) { return ElementwiseKind::tanh; })
+ .Case([](ErfOp) { return ElementwiseKind::erf; })
+ .Default([&](Operation *op) {
+ llvm_unreachable("unhandled case in named to elementwise");
+ return ElementwiseKind::sub;
+ });
+}
+
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+ using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(NamedOpTy op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<NamedAttribute> attrs;
+ auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
+ attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
+ attrs.push_back(
+ rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+ op.getDpsInits(), attrs);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
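Like the other pattern collections in this patch, these rewrites can also be driven outside the morph pass; a small sketch (the driver function and `module` payload are hypothetical):

```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Rewrites elementwise-like named ops (linalg.exp, linalg.add, ...) into
// linalg.elementwise carrying the matching `kind` attribute and the original
// indexing maps, as the pattern above does.
static LogicalResult namedToElementwise(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  linalg::populateLinalgNamedToElementwisePatterns(patterns);
  return applyPatternsGreedily(module.getOperation(), std::move(patterns));
}
```

Transpose/broadcast folding on `linalg.elementwise` can then run as a follow-up, which is the motivation the file header states.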
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e62523..8942670 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
Value paddingValue;
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
- auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
- complexTy, complexAttr);
- } else {
- paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
- cast<TypedAttr>(paddingValueAttr));
+ if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
+ }
+ } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+ paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ getElementTypeOrSelf(v.getType()));
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+ paddingValue =
+ arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
}
+ assert(paddingValue && "failed to create value from padding attribute");
// Pad the operand to the bounding box defined by `paddedShape`.
SmallVector<int64_t> tensorShape;
@@ -257,11 +263,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface>
-linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- PadSizeComputationFunction computePaddingSizeFun) {
+FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
+ RewriterBase &rewriter, TilingInterface opToPad,
+ const PadTilingInterfaceOptions &constOptions,
+ SmallVector<tensor::PadOp> &padOps,
+ const PadSizeComputationFunction &computePaddingSizeFun) {
LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
Location loc = opToPad.getLoc();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index a2bd9d9..27ccf3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -21,7 +21,7 @@
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
-#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
+#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
@@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp
}
};
-struct LinalgNamedOpConversionPass
- : public impl::LinalgNamedOpConversionPassBase<
- LinalgNamedOpConversionPass> {
- using impl::LinalgNamedOpConversionPassBase<
- LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
+struct SimplifyDepthwiseConvPass
+ : public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> {
+ using impl::SimplifyDepthwiseConvPassBase<
+ SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- populateLinalgNamedOpConversionPatterns(patterns);
+ populateSimplifyDepthwiseConvPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
-void mlir::linalg::populateLinalgNamedOpConversionPatterns(
+void mlir::linalg::populateSimplifyDepthwiseConvPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
patterns.getContext());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6..35ba4f15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bb725f2..e9a8b25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -29,6 +29,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <utility>
@@ -38,9 +39,6 @@
using namespace mlir;
using namespace mlir::linalg;
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
-
//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//
@@ -91,11 +89,11 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
}
return true;
}
+#endif // NDEBUG
static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}
-#endif // NDEBUG
/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
@@ -276,23 +274,18 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
highs, paddingValue, /*nofold=*/false);
- LLVM_DEBUG(
- DBGSNL(); DBGSNL();
- DBGS() << "insertPositions: "
- << llvm::interleaved(packingMetadata.insertPositions);
- DBGSNL(); DBGS() << "outerPositions: "
- << llvm::interleaved(packingMetadata.outerPositions);
- DBGSNL(); DBGS() << "packedShape: "
- << llvm::interleaved(packedTensorType.getShape());
- DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
- << llvm::interleaved(packedToStripMinedShapePerm);
- DBGSNL();
- DBGS() << "reassociations: "
- << llvm::interleaved(llvm::map_range(
- packingMetadata.reassociations, stringifyReassocIndices));
- DBGSNL();
- DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
- DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+ LDBG() << "insertPositions: "
+ << llvm::interleaved(packingMetadata.insertPositions);
+ LDBG() << "outerPositions: "
+ << llvm::interleaved(packingMetadata.outerPositions);
+ LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
+ LDBG() << "packedToStripMinedShapePerm: "
+ << llvm::interleaved(packedToStripMinedShapePerm);
+ LDBG() << "reassociations: "
+ << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
+ stringifyReassocIndices));
+ LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
+ LDBG() << "collapsed type: " << collapsed;
if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
// Pack ops which operate as simple pads may not produce legal
@@ -317,7 +310,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
/*offsets=*/zeros, sizes, /*strides=*/ones);
- LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+ LDBG() << "insert_slice op: " << insertSliceOp;
rewriter.replaceOp(packOp, insertSliceOp->getResults());
@@ -339,10 +332,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
auto transposeOp = linalg::TransposeOp::create(
rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "reshape op: " << reshapeOp; DBGSNL();
- DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
- DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+ LDBG() << "reshape op: " << reshapeOp;
+ LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
+ LDBG() << "transpose op: " << transposeOp;
// 7. Replace packOp by transposeOp.
rewriter.replaceOp(packOp, transposeOp->getResults());
@@ -410,21 +402,16 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
packedToStripMinedShapePerm);
- LLVM_DEBUG(
- DBGSNL(); DBGSNL();
- DBGS() << "insertPositions: "
- << llvm::interleaved(packingMetadata.insertPositions);
- DBGSNL(); DBGS() << "packedShape: "
- << llvm::interleaved(packedTensorType.getShape());
- DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
- << llvm::interleaved(packedToStripMinedShapePerm);
- DBGSNL();
- DBGS() << "reassociations: "
- << llvm::interleaved(llvm::map_range(
- packingMetadata.reassociations, stringifyReassocIndices));
- DBGSNL();
- DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
- DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+ LDBG() << "insertPositions: "
+ << llvm::interleaved(packingMetadata.insertPositions);
+ LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
+ LDBG() << "packedToStripMinedShapePerm: "
+ << llvm::interleaved(packedToStripMinedShapePerm);
+ LDBG() << "reassociations: "
+ << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
+ stringifyReassocIndices));
+ LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
+ LDBG() << "collapsed type: " << collapsedType;
// 4. Collapse from the stripMinedShape to the padded result.
auto reshapeOp = tensor::CollapseShapeOp::create(
@@ -486,10 +473,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
- LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
- << "maps: " << llvm::interleaved(indexingMaps) << "\n"
- << "iterators: " << llvm::interleaved(iteratorTypes)
- << "\n");
+ LDBG() << "Start packing: " << linalgOp;
+ LDBG() << "maps: " << llvm::interleaved(indexingMaps);
+ LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
SmallVector<linalg::PackOp> packOps;
SmallVector<linalg::UnPackOp> unPackOps;
@@ -511,14 +497,11 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
- LLVM_DEBUG(
- DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
- << "\n"
- << "maps: " << llvm::interleaved(indexingMaps) << "\n"
- << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
- << "packedDimForEachOperand: "
- << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
- << "\n");
+ LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
+ LDBG() << "maps: " << llvm::interleaved(indexingMaps);
+ LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
+ LDBG() << "packedDimForEachOperand: "
+ << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
}
// Step 2. Propagate packing to all LinalgOp operands.
@@ -534,10 +517,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
SmallVector<OpFoldResult> innerPackSizes =
listOfPackedOperandsDim.extractPackSizesForOperand(pos);
- LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
- << "innerPos: " << llvm::interleaved(innerPos) << "\n"
- << "innerPackSizes: "
- << llvm::interleaved(innerPackSizes) << "\n");
+ LDBG() << "operand: " << operand;
+ LDBG() << "innerPos: " << llvm::interleaved(innerPos);
+ LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
if (innerPackSizes.empty()) {
inputsAndInits.push_back(operand);
continue;
@@ -776,8 +758,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
- LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
- << numLoops << "\nin: " << linalgOp << "\n");
+ LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
+ << " in: " << linalgOp;
return rewriter.notifyMatchFailure(
linalgOp, "need 3+ loops to find a matmul to pack");
}
@@ -801,8 +783,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
FailureOr<ContractionDimensions> maybeDimensions =
inferContractionDims(linalgOp);
if (failed(maybeDimensions)) {
- LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
- << "\n");
+ LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer matmul iterators");
}
@@ -814,10 +795,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// to plug a heuristic.
int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
kPos = maybeDimensions->k.back();
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "Start packing generic op greedily with (m@" << mPos
- << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
- << "\n";);
+ LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
+ << nPos << ", k@" << kPos << "): " << linalgOp;
// 2.a. Rewrite as a generic.
auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
@@ -833,14 +812,14 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// not change the indexings of any operand.
SmallVector<int64_t> permutation =
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
- LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
+ LDBG() << "perm: " << llvm::interleaved(permutation);
// Sign .. unsigned pollution.
SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
FailureOr<GenericOp> interchangeResult =
interchangeGenericOp(rewriter, genericOp, unsignedPerm);
assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
genericOp = *interchangeResult;
- LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
+ LDBG() << "Generalized Op to pack: " << genericOp;
// At this point, the op iterators are normalized to {leading, k, m, n}.
// The layouts induced by packing will always be:
@@ -862,12 +841,11 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// Add leading zeros to match numLoops, we only pack the last 3 dimensions
// post interchange.
- LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
- << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
- << "loopRanges: "
- << llvm::interleaved(llvm::map_range(
- loopRanges, [](Range r) { return r.size; }))
- << "\n");
+ LDBG() << "paddedSizesNextMultipleOf: "
+ << llvm::interleaved(paddedSizesNextMultipleOf);
+ LDBG() << "loopRanges: "
+ << llvm::interleaved(
+ llvm::map_range(loopRanges, [](Range r) { return r.size; }));
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
rewriter.getIndexAttr(0));
for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
@@ -883,8 +861,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
{loopRanges[adjustedPackedSizes.size()].size,
rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
}
- LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
- << llvm::interleaved(adjustedPackedSizes) << "\n");
+ LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time. One would still need to check that
@@ -1214,9 +1191,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
- LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
- << "perm: " << llvm::interleaved(srcPermForTranspose)
- << "\n");
+ LDBG() << "Pack permutation: " << packOp;
+ LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a2a4335..2650488 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::MatmulTransposeAOp::create(
+ newMatmulOp = MatmulTransposeAOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
- newMatmulOp = linalg::MatmulTransposeBOp::create(
+ newMatmulOp = MatmulTransposeBOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
@@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
+ newMatmulOp = BatchMatmulTransposeAOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
- newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
+ newMatmulOp = BatchMatmulTransposeBOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0860cea..406f05c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
- useInBoundsInsteadOfMasking);
+ useInBoundsInsteadOfMasking,
+ /*inputScalableVecSizes=*/{});
// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1878,19 +1879,46 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-/// Vector::TransferReadOp - Reads a vector from the source tensor
-/// vector::TransposeOp - Transpose the Source tensor
-/// ShapeCastOp - Reshape the data based on the target.
-/// vector::TransferWriteOp. - Write the result vector back to the destination
-/// tensor.
-/// If the vector sizes are not provided:
-/// * the vector sizes are determined by the input operand and attributes,
-/// * update the inBounds attribute instead of masking.
+/// Vectorize `linalg.unpack` as:
+/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
+///
+/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
+/// for the xfer_read operation). This is sufficient to infer the other vector
+/// sizes required here.
+///
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined from the input tensor's static shape.
+/// * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+/// %unpack = linalg.unpack %src
+/// inner_dims_pos = [0, 1]
+/// inner_tiles = [8, 8]
+/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
+/// ```
+/// is vectorized as:
+/// ```
+/// %read = vector.transfer_read %src
+/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
+/// %tr = vector.transpose %read, [0, 2, 1, 3]
+/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
+/// %sc = vector.shape_cast %tr
+/// : vector<1x8x1x8xf32> to vector<8x8xf32>
+/// %vector = vector.transfer_write %sc into %dest
+/// : vector<8x8xf32>, tensor<8x8xf32>
+/// ```
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims,
SmallVectorImpl<Value> &newResults) {
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
+ "Invalid number of input vector sizes!");
+ assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
+ "Incompatible number of vector sizes and vector scalable flags!");
+ }
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
@@ -1898,88 +1926,40 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
RankedTensorType unpackTensorType = unpackOp.getSourceType();
- ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
- ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
bool useInBoundsInsteadOfMasking = false;
- ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
-
- auto destSize = unpackOp.getDestRank();
-
- if (!inputVectorSizes.empty())
- assert(inputVectorSizes.size() == destSize &&
- "Incorrect number of input vector sizes");
-
- // vectorSizes is the shape of the vector that will be used to do final
- // write on the destination tensor. It is set like this: Let's say the
- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
- // Thus:
- // 1. vectorSizes = sourceShape.take_front(N)
- // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
- // innerTiles attribute value.
- SmallVector<int64_t> vectorSizes(inputVectorSizes);
- if (vectorSizes.empty()) {
- llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
- if (!outerDimsPerm.empty())
- applyPermutationToVector(vectorSizes, outerDimsPerm);
- for (auto [i, pos] : llvm::enumerate(innerDimPos))
- vectorSizes[pos] *= innerTiles[i];
- useInBoundsInsteadOfMasking = true;
- }
+ Location loc = unpackOp->getLoc();
- // readVectorSizes is the size of tensor used to read and apply mask. It is
- // set like this: Let's say the vectorSize (VS) array is size 'N' and
- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
- // size M-N
- // Thus:
- // - initially: readVectorSizes = vectorInputSizes
- // - Divide all the readMaskShape locations pointed by innerDimPos
- // by the innerTileSize attribute value.
- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
- // - Append the remaining shape from SS
- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
- // 128] and outer_dims_perm is [1, 0] then read shape is:
- // ReadVectorSizes(initial): [512, 128]
- // Final Value(after innerDim Adjustment): [512/32, 128/16]
- // = [16, 8]
- // After applying outer_dims_perm: [8, 16]
- // After appending the rest of the sourceShape: [8, 16, 32, 16]
-
- SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
- for (auto [index, size] : enumerate(innerTiles)) {
- readVectorSizes[innerDimPos[index]] =
- llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
- }
- if (!outerDimsPerm.empty()) {
- applyPermutationToVector(readVectorSizes, outerDimsPerm);
- }
- readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
- sourceShape.end());
+ // Obtain vector sizes for the read operation.
+ SmallVector<int64_t> readVectorSizes(inputVectorSizes);
+ SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
- Location loc = unpackOp->getLoc();
+ // In the absence of input-vector-sizes, use the _static_ input tensor shape.
+ if (inputVectorSizes.empty()) {
+ if (ShapedType::isDynamicShape(sourceShape))
+ return failure();
+ readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
+ useInBoundsInsteadOfMasking = true;
+ }
+
+ // -- Generate the read operation --
auto padValue = arith::ConstantOp::create(
rewriter, loc,
rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
-
- // Read result, mask if necessary. If transferReadOp shape is not equal
- // to shape of source, then a mask is necessary.
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false);
+ useInBoundsInsteadOfMasking, readScalableVectorFlags);
+ // -- Generate the transpose operation --
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
getUnPackInverseSrcPerm(unpackOp, packMetadata);
- // Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = vector::TransposeOp::create(
rewriter, loc, readResult, lastDimToInsertPosPerm);
- // Collapse the vector to the size required by result.
+ // -- Generate the shape_cast operation --
VectorType collapsedVecType = getCollapsedVecType(
transposeOp.getType(),
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
@@ -1987,9 +1967,11 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
rewriter, loc, collapsedVecType, transposeOp->getResult(0));
+ // -- Generate the write operation --
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
+
newResults.push_back(write->getResult(0));
return success();
}
@@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
assert(succeeded(status) && "failed to reify result shapes");
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false);
+ /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
return success();
}
-/// Need to check if the inner-tiles are static/constant.
+/// This hook considers two cases:
+/// (1) If the input-vector-sizes are empty, then the vector sizes will be
+///     inferred. This is only possible when all shapes are static.
+/// (2) If the input-vector-sizes are non-empty (i.e. user-provided), then
+///     carry out basic sanity-checking.
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
+ // If there are no input vector sizes and all shapes are static, there is
+ // nothing left to check.
+ if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
+ unpackOp.getSourceType().hasStaticShape())
+ return success();
- if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
- return !getConstantIntValue(res).has_value();
- })) {
- LDBG() << "Inner-tiles must be constant: " << unpackOp;
+  // The number of input vector sizes must be equal to the read-vector rank,
+  // i.e. the rank of the source tensor.
+ if (!inputVectorSizes.empty() &&
+ (inputVectorSizes.size() != unpackOp.getSourceRank())) {
+ LDBG() << "Incorrect number of input vector sizes";
return failure();
}
- ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
- bool satisfyEmptyCond = inputVectorSizes.empty() &&
- unpackOp.getDestType().hasStaticShape() &&
- unpackOp.getSourceType().hasStaticShape();
- if (!satisfyEmptyCond &&
- failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
+
+ // Check the vector sizes for the read operation.
+ if (failed(vector::isValidMaskedInputVector(
+ unpackOp.getSourceType().getShape(), inputVectorSizes))) {
+ LDBG() << "Invalid vector sizes for the read operation";
return failure();
+ }
return success();
}
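The size-selection rule introduced above boils down to: take the user-provided sizes when present (one per source dimension), otherwise fall back to the static source shape. A standalone paraphrase in plain C++, using hypothetical shapes and a simplified helper name, is:

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified sketch of the read-vector-size selection in
// vectorizeAsTensorUnpackOp (not the actual MLIR helper).
static std::vector<int64_t>
chooseReadSizes(const std::vector<int64_t> &inputVectorSizes,
                const std::vector<int64_t> &staticSourceShape) {
  if (!inputVectorSizes.empty()) {
    // User-provided sizes must cover every source dimension.
    assert(inputVectorSizes.size() == staticSourceShape.size());
    return inputVectorSizes;
  }
  // No user sizes: only legal when the source shape is fully static.
  return staticSourceShape;
}

int main() {
  // tensor<8x8x32x16xf32> source, no user sizes -> read vector is 8x8x32x16.
  assert(chooseReadSizes({}, {8, 8, 32, 16}) ==
         std::vector<int64_t>({8, 8, 32, 16}));
  return 0;
}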
@@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
LDBG() << "pad value is not constant: " << packOp;
return failure();
}
+
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
bool satisfyEmptyCond = true;
if (inputVectorSizes.empty()) {
@@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}
-/// Preconditions for scalable vectors. This is quite restrictive - it models
-/// the fact that in practice we would only make selected dimensions scalable.
+/// Preconditions for scalable vectors.
+///
+/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
+/// models the fact that in practice we would only make selected dimensions
+/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
+/// unconditionally - we have yet to identify meaningful conditions.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes,
@@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
auto linalgOp = dyn_cast<LinalgOp>(op);
- // Cond 1: There's been no need for scalable vectorisation of
- // non-linalg Ops so far
- if (!linalgOp)
- return failure();
+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+ // exception of UnpackOp for which there is a dedicated hook.
+ if (!linalgOp) {
+ return success(isa<linalg::UnPackOp>(op));
+ }
// Cond 2: There's been no need for more than 2 scalable dims so far
if (numOfScalableDims > 2)
@@ -2565,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization";
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG()
<< "Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported";
@@ -2606,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) ||
hasReductionIterator(linalgOp));
}
@@ -2750,7 +2743,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
})
.Case<linalg::UnPackOp>([&](auto unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
- inputVectorSizes, results);
+ inputVectorSizes,
+ inputScalableVecDims, results);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3142,7 +3136,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
rewriter, loc, source, vecType.getShape(), padValue,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+ /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+ /*inputScalableVecSizes=*/{});
// Create write
auto writeIndices =
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index e1c0c24..d37a056 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
- ExpandPatterns.cpp
+ ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4a40a30..cd68039 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -13,14 +13,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXPANDOPSPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
/// Create a float constant.
static Value createFloatConst(Location loc, Type type, APFloat value,
OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
return success();
}
-void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
- patterns.add(convertCtlzOp);
-}
-
-void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertSinhOp);
-}
-
-void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertCoshOp);
-}
-
-void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanOp);
-}
-
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanhOp);
-}
-
-void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAsinhOp);
-}
-
-void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertAcoshOp);
-}
-
-void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAtanhOp);
-}
-
-void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
- patterns.add(convertFmaFOp);
-}
-
-void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
- patterns.add(convertCeilOp);
-}
-
-void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
- patterns.add(convertExp2fOp);
-}
-
-void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
- patterns.add(convertPowfOp);
-}
-
-void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowIOp);
-}
-
-void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundOp);
+// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
+static LogicalResult convertClampfOp(math::ClampFOp op,
+ PatternRewriter &rewriter) {
+ auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
+ op.getMin(), op.getFastmath());
+ rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
+ op.getFastmath());
+ return success();
}
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundEvenOp);
+void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
+ ArrayRef<StringRef> opMnemonics) {
+ auto filter = [&](StringRef name) {
+    // This should be a static assert, and `consume_front` should take a Twine,
+    // but neither is currently possible. TODO: augment `StringRef::consume_front`
+    // and make `getDialectNamespace` use `std::string_view`.
+ assert("math" == MathDialect::getDialectNamespace());
+ name.consume_front("math.");
+ return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
+ };
+ if (filter(CountLeadingZerosOp::getOperationName()))
+ patterns.add(convertCtlzOp);
+ if (filter(SinhOp::getOperationName()))
+ patterns.add(convertSinhOp);
+ if (filter(CoshOp::getOperationName()))
+ patterns.add(convertCoshOp);
+ if (filter(TanOp::getOperationName()))
+ patterns.add(convertTanOp);
+ if (filter(TanhOp::getOperationName()))
+ patterns.add(convertTanhOp);
+ if (filter(AsinhOp::getOperationName()))
+ patterns.add(convertAsinhOp);
+ if (filter(AcoshOp::getOperationName()))
+ patterns.add(convertAcoshOp);
+ if (filter(AtanhOp::getOperationName()))
+ patterns.add(convertAtanhOp);
+ if (filter(FmaOp::getOperationName()))
+ patterns.add(convertFmaFOp);
+ if (filter(CeilOp::getOperationName()))
+ patterns.add(convertCeilOp);
+ if (filter(Exp2Op::getOperationName()))
+ patterns.add(convertExp2fOp);
+ if (filter(PowFOp::getOperationName()))
+ patterns.add(convertPowfOp);
+ if (filter(FPowIOp::getOperationName()))
+ patterns.add(convertFPowIOp);
+ if (filter(RoundOp::getOperationName()))
+ patterns.add(convertRoundOp);
+ if (filter(RoundEvenOp::getOperationName()))
+ patterns.add(convertRoundEvenOp);
+ if (filter(RsqrtOp::getOperationName()))
+ patterns.add(convertRsqrtOp);
+ if (filter(ClampFOp::getOperationName()))
+ patterns.add(convertClampfOp);
}
-void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
- patterns.add(convertRsqrtOp);
-}
+//===----------------------------------------------------------------------===//
+// MathExpandOpsPass pass
+//===----------------------------------------------------------------------===//
+namespace {
+struct MathExpandOpsPass final
+ : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
+ using MathExpandOpsPassBase::MathExpandOpsPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ SmallVector<StringRef> mnemonics =
+ llvm::to_vector_of<StringRef>(opMnemonics);
+ math::populateExpansionPatterns(patterns, mnemonics);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
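With the per-op populate functions folded into a single entry point, downstream code selects expansions via mnemonics (the "math." prefix is consumed by the filter). A minimal sketch of a caller, assuming `ctx` and `moduleOp` are already in scope inside a pass:

  RewritePatternSet patterns(ctx);
  // Expand only powf and clampf; an empty list would register every
  // supported expansion, matching the default pass behaviour.
  math::populateExpansionPatterns(patterns, {"powf", "clampf"});
  if (failed(applyPatternsGreedily(moduleOp, std::move(patterns))))
    return signalPassFailure();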
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c..b59d73d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
index bbb269b..1939195 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -21,9 +21,9 @@ namespace {
struct ReallocOpInterface
: public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface,
ReallocOp> {
- void
- populateDependencies(Operation *op,
- RegisterDependenciesFn registerDependenciesFn) const {
+ void populateDependencies(
+ Operation *op,
+ const RegisterDependenciesFn &registerDependenciesFn) const {
auto reallocOp = cast<ReallocOp>(op);
// memref.realloc may return the source operand.
registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 9771bd2..d35566a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
- if (!viewLikeOp)
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 5d3cec4..860384f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) {
/// propagate the type change and erase old subview ops.
static void replaceUsesAndPropagateType(RewriterBase &rewriter,
Operation *oldOp, Value val) {
- SmallVector<Operation *> opsToDelete;
- SmallVector<OpOperand *> operandsToReplace;
-
- // Save the operand to replace / delete later (avoid iterator invalidation).
- // TODO: can we use an early_inc iterator?
- for (OpOperand &use : oldOp->getUses()) {
- // Non-subview ops will be replaced by `val`.
- auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
- if (!subviewUse) {
- operandsToReplace.push_back(&use);
+ // Iterate with early_inc to erase current user inside the loop.
+ for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
+ Operation *user = use.getOwner();
+ if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
+ // `subview(old_op)` is replaced by a new `subview(val)`.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(subviewUse);
+ MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
+ subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
+ subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+ subviewUse.getStaticStrides());
+ Value newSubview = memref::SubViewOp::create(
+ rewriter, subviewUse->getLoc(), newType, val,
+ subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+ subviewUse.getMixedStrides());
+
+ // Ouch recursion ... is this really necessary?
+ replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
+
+ // Safe to erase.
+ rewriter.eraseOp(subviewUse);
continue;
}
-
- // `subview(old_op)` is replaced by a new `subview(val)`.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(subviewUse);
- MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
- subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
- subviewUse.getStaticStrides());
- Value newSubview = memref::SubViewOp::create(
- rewriter, subviewUse->getLoc(), newType, val,
- subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
- subviewUse.getMixedStrides());
-
- // Ouch recursion ... is this really necessary?
- replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
- opsToDelete.push_back(use.getOwner());
+ // Non-subview: replace with new value.
+ rewriter.startOpModification(user);
+ use.set(val);
+ rewriter.finalizeOpModification(user);
}
-
- // Perform late replacement.
- // TODO: can we use an early_inc iterator?
- for (OpOperand *operand : operandsToReplace) {
- Operation *op = operand->getOwner();
- rewriter.startOpModification(op);
- operand->set(val);
- rewriter.finalizeOpModification(op);
- }
-
- // Perform late op erasure.
- // TODO: can we use an early_inc iterator?
- for (Operation *op : opsToDelete)
- rewriter.eraseOp(op);
}
// Transformation to do multi-buffering/array expansion to remove dependencies
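The rewrite above folds the deferred worklists into a single pass over the uses by relying on `llvm::make_early_inc_range`, which advances the underlying iterator at dereference time so the current use (or its owner) can be erased safely. The generic shape of the idiom, with hypothetical `oldValue`, `newValue`, and `shouldErase` names, is:

  for (OpOperand &use : llvm::make_early_inc_range(oldValue.getUses())) {
    Operation *user = use.getOwner();
    if (shouldErase(user)) {
      rewriter.eraseOp(user); // safe: the iterator already moved past `use`
      continue;
    }
    rewriter.modifyOpInPlace(user, [&] { use.set(newValue); });
  }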
@@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
- // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
- // handle dealloc uses separately..
+  // 5. Due to the recursive nature of replaceUsesAndPropagateType, we need
+  // to handle dealloc uses separately.
for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
if (!deallocOp)
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 5af46a4..3de9c38 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
MemrefValue skipViewLikeOps(MemrefValue source) {
while (auto op = source.getDefiningOp()) {
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
- source = cast<MemrefValue>(viewLike.getViewSource());
- continue;
+ if (source == viewLike.getViewDest()) {
+ source = cast<MemrefValue>(viewLike.getViewSource());
+ continue;
+ }
}
return source;
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index cc03974..8474244 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -345,6 +345,19 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//
+unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
+ switch (kind) {
+ case TensorMapSwizzleKind::SWIZZLE_32B:
+ return 32;
+ case TensorMapSwizzleKind::SWIZZLE_64B:
+ return 64;
+ case TensorMapSwizzleKind::SWIZZLE_128B:
+ return 128;
+ default:
+ return 0;
+ }
+}
+
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
Operation *op, nvgpu::TensorMapDescriptorType descType,
std::optional<MemRefType> memrefType = std::nullopt) {
@@ -373,10 +386,11 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
unsigned lastDimensionByte =
descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
- if (lastDimensionByte != kMaxTMALastdimByte)
+ unsigned expectByte = getSwizzleBytes(descType.getSwizzle());
+ if (lastDimensionByte != expectByte)
return op->emitError() << "the tensormap descriptor must have last "
"dimension of "
- << kMaxTMALastdimByte << " bytes but it is "
+ << expectByte << " bytes but it is "
<< lastDimensionByte << " bytes";
}
@@ -408,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
<< descMemref << " != " << dstMemref;
}
+ int lastDimBytes =
+ descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
+ if (lastDimBytes % 16 != 0) {
+ return op->emitError() << "the bytes in the last dimension of the tensor "
+ "map must be a multiple of 16";
+ }
return std::nullopt;
}
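A worked example of the new check (element type assumed to be f16): with `SWIZZLE_128B` the last dimension must span 128 bytes, i.e. 64 elements, whereas previously the comparison was against a single constant regardless of the swizzle kind. In sketch form, reusing the helper defined above:

  unsigned elementBits = 16;                                              // f16
  unsigned expectedBytes =
      getSwizzleBytes(TensorMapSwizzleKind::SWIZZLE_128B);                // 128
  unsigned lastDimElements = expectedBytes * 8 / elementBits;             // 64
  // Also satisfies the new multiple-of-16-bytes requirement.
  assert((lastDimElements * elementBits / 8) % 16 == 0);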
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 485bb73..ded4c7a 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -173,9 +173,7 @@ void OpenACCDialect::initialize() {
//===----------------------------------------------------------------------===//
static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
- if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
- return true;
- return false;
+ return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
}
static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
@@ -1390,6 +1388,36 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::ParallelOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
+void acc::ParallelOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
static ParseResult parseNumGangs(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -2041,6 +2069,36 @@ void acc::SerialOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::SerialOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
+void acc::SerialOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// KernelsOp
//===----------------------------------------------------------------------===//
@@ -3059,6 +3117,20 @@ void acc::LoopOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767..6e43f28 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3874,6 +3874,159 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
+ return getInTypeAttr().getValue();
+}
+
+/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
+/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+/// attr-dict-without-keyword
+static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ auto &builder = parser.getBuilder();
+ bool hasOperands = false;
+ std::int32_t typeparamsSize = 0;
+
+ // Parse device number as a new operand
+ mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+ mlir::Type deviceType;
+ if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+ return mlir::failure();
+ if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+ return mlir::failure();
+ if (parser.parseComma())
+ return mlir::failure();
+
+ mlir::Type intype;
+ if (parser.parseType(intype))
+ return mlir::failure();
+ result.addAttribute("in_type", mlir::TypeAttr::get(intype));
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::Type> typeVec;
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeVec) || parser.parseRParen())
+ return mlir::failure();
+ typeparamsSize = operands.size();
+ hasOperands = true;
+ }
+ std::int32_t shapeSize = 0;
+ if (!parser.parseOptionalComma()) {
+ // parse size to scale by, vector of n dimensions of type index
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
+ return mlir::failure();
+ shapeSize = operands.size() - typeparamsSize;
+ auto idxTy = builder.getIndexType();
+ for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
+ typeVec.push_back(idxTy);
+ hasOperands = true;
+ }
+ if (hasOperands &&
+ parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
+ result.operands))
+ return mlir::failure();
+
+ mlir::Type restype = builder.getIntegerType(64);
+ if (!restype) {
+ parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
+ return mlir::failure();
+ }
+ llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
+ result.addAttribute("operandSegmentSizes",
+ builder.getDenseI32ArrayAttr(segmentSizes));
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.addTypeToList(restype, result.types))
+ return mlir::failure();
+ return mlir::success();
+}
+
+mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseTargetAllocMemOp(parser, result);
+}
+
+void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+ p << " ";
+ p.printOperand(getDevice());
+ p << " : ";
+ p << getDevice().getType();
+ p << ", ";
+ p << getInType();
+ if (!getTypeparams().empty()) {
+ p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
+ }
+ for (auto sh : getShape()) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"in_type", "operandSegmentSizes"});
+}
+
+llvm::LogicalResult omp::TargetAllocMemOp::verify() {
+ mlir::Type outType = getType();
+ if (!mlir::dyn_cast<IntegerType>(outType))
+    return emitOpError("must be an integer type");
+ return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+ // Check that region exists and is not empty
+ Region &region = getRegion();
+ if (region.empty())
+ return emitOpError("region cannot be empty");
+ // Verify single entry point.
+ Block &entryBlock = region.front();
+ if (entryBlock.empty())
+ return emitOpError("region must contain a structured block");
+ // Verify single exit point.
+ bool hasTerminator = false;
+ for (Block &block : region) {
+ if (isa<TerminatorOp>(block.back())) {
+ if (hasTerminator) {
+ return emitOpError("region must have exactly one terminator");
+ }
+ hasTerminator = true;
+ }
+ }
+ if (!hasTerminator) {
+ return emitOpError("region must be terminated with omp.terminator");
+ }
+ auto walkResult = region.walk([&](Operation *op) -> WalkResult {
+ // No implicit barrier at end
+ if (isa<BarrierOp>(op)) {
+ return emitOpError(
+ "explicit barriers are not allowed in workdistribute region");
+ }
+ // Check for invalid nested constructs
+ if (isa<ParallelOp>(op)) {
+ return emitOpError(
+ "nested parallel constructs not allowed in workdistribute");
+ }
+ if (isa<TeamsOp>(op)) {
+ return emitOpError(
+ "nested teams constructs not allowed in workdistribute");
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return failure();
+
+ Operation *parentOp = (*this)->getParentOp();
+ if (!llvm::dyn_cast<TeamsOp>(parentOp))
+ return emitOpError("workdistribute must be nested under teams");
+ return success();
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index 497468b..bd1e655 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -1,3 +1,22 @@
+set(LLVM_OPTIONAL_SOURCES
+ MemorySpaceInterfaces.cpp
+ PtrAttrs.cpp
+ PtrTypes.cpp
+ PtrDialect.cpp
+)
+
+add_mlir_dialect_library(
+ MLIRPtrMemorySpaceInterfaces
+ MemorySpaceInterfaces.cpp
+
+ DEPENDS
+ MLIRPtrOpsEnumsGen
+ MLIRPtrMemorySpaceInterfacesIncGen
+ LINK_LIBS
+ PUBLIC
+ MLIRIR
+)
+
add_mlir_dialect_library(
MLIRPtrDialect
PtrAttrs.cpp
@@ -15,4 +34,5 @@ add_mlir_dialect_library(
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
MLIRViewLikeInterface
+ MLIRPtrMemorySpaceInterfaces
)
diff --git a/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp
new file mode 100644
index 0000000..059e67f
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp
@@ -0,0 +1,15 @@
+//===-- MemorySpaceInterfaces.cpp - ptr memory space interfaces -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ptr dialect memory space interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
index 772d25d..ac3bcd6 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
@@ -22,26 +22,30 @@ constexpr const static unsigned kBitsInByte = 8;
//===----------------------------------------------------------------------===//
bool GenericSpaceAttr::isValidLoad(
- Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidStore(
- Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidAtomicOp(
ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
- IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const {
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidAtomicXchg(
Type type, ptr::AtomicOrdering successOrdering,
- ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
+ ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c5ec0ca..d5976b9 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -85,6 +85,124 @@ LogicalResult FromPtrOp::verify() {
}
//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the attributes and the type of atomic memory access operations.
+template <typename OpTy>
+static LogicalResult
+verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
+ if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
+ if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
+ return memOp.emitOpError("unsupported ordering '")
+ << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
+ if (!memOp.getAlignment())
+ return memOp.emitOpError("expected alignment for atomic access");
+ return success();
+ }
+ if (memOp.getSyncscope()) {
+ return memOp.emitOpError(
+ "expected syncscope to be null for non-atomic access");
+ }
+ return success();
+}
+
+/// Verifies that the alignment attribute is a power of 2 if present.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!alignment)
+ return success();
+ if (alignment.value() <= 0)
+ return emitError() << "alignment must be positive";
+ if (!llvm::isPowerOf2_64(alignment.value()))
+ return emitError() << "alignment must be a power of 2";
+ return success();
+}
+
+void LoadOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+ // Volatile operations can have target-specific read-write effects on
+ // memory besides the one referred to by the pointer operand.
+ // Similarly, atomic operations that are monotonic or stricter cause
+ // synchronization that from a language point-of-view, are arbitrary
+ // read-writes into memory.
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
+ getOrdering() != AtomicOrdering::unordered)) {
+ effects.emplace_back(MemoryEffects::Write::get());
+ effects.emplace_back(MemoryEffects::Read::get());
+ }
+}
+
+LogicalResult LoadOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(),
+ &dataLayout, emitDiag))
+ return failure();
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+ return verifyAtomicMemOp(*this,
+ {AtomicOrdering::release, AtomicOrdering::acq_rel});
+}
+
+void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
+ Value addr, unsigned alignment, bool isVolatile,
+ bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
+ AtomicOrdering ordering, StringRef syncscope) {
+ build(builder, state, type, addr,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
+ syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+void StoreOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+ // Volatile operations can have target-specific read-write effects on
+ // memory besides the one referred to by the pointer operand.
+ // Similarly, atomic operations that are monotonic or stricter cause
+ // synchronization that from a language point-of-view, are arbitrary
+ // read-writes into memory.
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
+ getOrdering() != AtomicOrdering::unordered)) {
+ effects.emplace_back(MemoryEffects::Write::get());
+ effects.emplace_back(MemoryEffects::Read::get());
+ }
+}
+
+LogicalResult StoreOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(),
+ &dataLayout, emitDiag))
+ return failure();
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+ return verifyAtomicMemOp(*this,
+ {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
+}
+
+void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value addr, unsigned alignment, bool isVolatile,
+ bool isNonTemporal, bool isInvariantGroup,
+ AtomicOrdering ordering, StringRef syncscope) {
+ build(builder, state, value, addr,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ isVolatile, isNonTemporal, isInvariantGroup, ordering,
+ syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
+}
+
+//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -152,10 +270,6 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
-#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
-
-#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
-
#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
#define GET_TYPEDEF_CLASSES
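The alignment rule enforced by `verifyAlignment` is independent of the memory space: an absent alignment is accepted, and a present one must be a positive power of two. A standalone plain-C++ paraphrase (hypothetical helper, same logic as the positivity check plus `llvm::isPowerOf2_64`):

#include <cassert>
#include <cstdint>
#include <optional>

static bool alignmentOk(std::optional<int64_t> alignment) {
  if (!alignment)
    return true; // no alignment attribute: nothing to check
  int64_t a = *alignment;
  return a > 0 && (a & (a - 1)) == 0; // positive power of two
}

int main() {
  assert(alignmentOk(std::nullopt));
  assert(alignmentOk(16));
  assert(!alignmentOk(0));
  assert(!alignmentOk(24));
  return 0;
}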
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 825d119..deb7109 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRQuantTransforms
StripFuncQuantTypes.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
DEPENDS
MLIRQuantTransformsIncGen
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0262a1b..84f9777 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -157,8 +157,7 @@ void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
-
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
}
LogicalResult ExecuteRegionOp::verify() {
@@ -318,9 +317,12 @@ void ConditionOp::getSuccessorRegions(
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, ValueRange initArgs,
- BodyBuilderFn bodyBuilder) {
+ BodyBuilderFn bodyBuilder, bool unsignedCmp) {
OpBuilder::InsertionGuard guard(builder);
+ if (unsignedCmp)
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
result.addOperands({lb, ub, step});
result.addOperands(initArgs);
for (Value v : initArgs)
@@ -450,6 +452,9 @@ static void printInitializationList(OpAsmPrinter &p,
}
void ForOp::print(OpAsmPrinter &p) {
+ if (getUnsignedCmp())
+ p << " unsigned";
+
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();
@@ -462,7 +467,8 @@ void ForOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
}
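With the new trailing flag, creating a loop whose bound comparison is unsigned mirrors the call sites updated below. A hedged sketch, assuming `rewriter`, `loc`, and the bound/step values are in scope:

  auto loop = scf::ForOp::create(
      rewriter, loc, lowerBound, upperBound, step, /*initArgs=*/ValueRange{},
      [](OpBuilder &, Location, Value, ValueRange) {},
      /*unsignedCmp=*/true);
  // Prints roughly as: scf.for unsigned %iv = %lb to %ub step %step { ... }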
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -472,6 +478,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument inductionVariable;
OpAsmParser::UnresolvedOperand lb, ub, step;
+ if (succeeded(parser.parseOptionalKeyword("unsigned")))
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
+
// Parse the induction variable followed by '='.
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
// Parse loop bounds.
@@ -562,7 +572,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
inits.append(newInitOperands.begin(), newInitOperands.end());
scf::ForOp newLoop = scf::ForOp::create(
rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
- [](OpBuilder &, Location, Value, ValueRange) {});
+ [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
// Generate the new yield values and append them to the scf.yield operation.
@@ -806,7 +816,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
// 2. Create the new forOp shell.
scf::ForOp newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterOperands);
+ forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
@@ -931,7 +942,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
scf::ForOp newForOp =
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
- forOp.getUpperBound(), forOp.getStep(), newIterArgs);
+ forOp.getUpperBound(), forOp.getStep(), newIterArgs,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
@@ -989,12 +1001,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference between two AffineValueMap is
/// dynamic.
-static std::optional<int64_t> computeConstDiff(Value l, Value u) {
+static std::optional<APInt> computeConstDiff(Value l, Value u) {
IntegerAttr clb, cub;
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
llvm::APInt lbValue = clb.getValue();
llvm::APInt ubValue = cub.getValue();
- return (ubValue - lbValue).getSExtValue();
+ return ubValue - lbValue;
}
// Else a simple pattern match for x + c or c + x
@@ -1003,7 +1015,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) {
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
matchPattern(
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff.getSExtValue();
+ return diff;
return std::nullopt;
}
@@ -1022,13 +1034,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}
- std::optional<int64_t> diff =
+ std::optional<APInt> diff =
computeConstDiff(op.getLowerBound(), op.getUpperBound());
if (!diff)
return failure();
// If the loop is known to have 0 iterations, remove it.
- if (*diff <= 0) {
+ bool zeroOrLessIterations =
+ diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
+ if (zeroOrLessIterations) {
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
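Keeping the difference as an `APInt` lets the zero-trip-count test respect the loop's signedness. A small worked example with assumed i32 bounds lb = 4 and ub = 0:

  llvm::APInt lb(/*numBits=*/32, 4), ub(/*numBits=*/32, 0);
  llvm::APInt diff = ub - lb;                             // 0xFFFFFFFC
  bool signedEmpty = diff.isZero() || diff.isNegative();  // true: drop the loop
  bool unsignedEmpty = diff.isZero();                     // false: keep the loop

So a signed loop with these bounds is removed, while an `unsigned` one is preserved.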
@@ -3384,9 +3398,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
@@ -4222,14 +4235,15 @@ LogicalResult scf::IndexSwitchOp::verify() {
<< "see yield operation here";
}
for (auto [idx, result, operand] :
- llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
- yield.getOperandTypes())) {
- if (result == operand)
+ llvm::enumerate(getResultTypes(), yield.getOperands())) {
+ if (!operand)
+ return yield.emitOpError() << "operand " << idx << " is null\n";
+ if (result == operand.getType())
continue;
return (emitOpError("expected result #")
<< idx << " of each region to be " << result)
.attachNote(yield.getLoc())
- << name << " returns " << operand << " here";
+ << name << " returns " << operand.getType() << " here";
}
return success();
};
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index aea842d..71fe987 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -147,6 +147,45 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
}
//===----------------------------------------------------------------------===//
+// ParallelForToNestedForOps
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
+ transform::TransformRewriter &rewriter,
+ transform::TransformResults &results, transform::TransformState &state) {
+ auto payload = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(payload))
+ return emitSilenceableError() << "expected a single payload op";
+
+ auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
+ if (!target) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "expected the payload to be scf.parallel";
+ diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+ return diag;
+ }
+
+ if (getNumResults() != 1) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "op expects one result, given "
+ << getNumResults();
+ diag.attachNote(target.getLoc()) << "payload op";
+ return diag;
+ }
+
+ FailureOr<scf::LoopNest> loopNest =
+ scf::parallelForToNestedFors(rewriter, target);
+ if (failed(loopNest)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "failed to convert parallel into nested fors";
+ return diag;
+ }
+
+ results.set(cast<OpResult>(getTransformed()[0]), {loopNest->loops.front()});
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index f8799c5..fb179e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -769,7 +769,8 @@ struct ForOpInterface
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), castedInitArgs);
+ forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block *loopBody = newForOp.getBody();
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 6d3bafb..a07d9d4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
+ ParallelForToNestedFors.cpp
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index bee7780..ae52af5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
- auto cmpOp = arith::CmpIOp::create(
- rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
- beforeBlock->getArgument(0), forOp.getUpperBound());
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
+ beforeBlock->getArgument(0),
+ forOp.getUpperBound());
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 1130538..7e7fba4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
bool *modifiedIR) {
if (modifiedIR)
*modifiedIR = false;
+
+ // TODO: Add support for unsigned loops.
+ if (forOp.getUnsignedCmp())
+ return failure();
+
LoopPipelinerInternal pipeliner;
if (!pipeliner.initializeLoopInfo(forOp, options))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 4752c08..f1203b2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override {
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Do not peel already peeled loops.
if (forOp->hasAttr(kPeeledLoopLabel))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
new file mode 100644
index 0000000..8f7d5e3
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -0,0 +1,86 @@
+//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ParallelOp to nested scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "parallel-for-to-nested-fors"
+using namespace mlir;
+
+FailureOr<scf::LoopNest>
+mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
+ scf::ParallelOp parallelOp) {
+
+ if (!parallelOp.getResults().empty())
+ return rewriter.notifyMatchFailure(
+ parallelOp, "Currently scf.parallel to scf.for conversion doesn't "
+ "support scf.parallel with results.");
+
+ rewriter.setInsertionPoint(parallelOp);
+
+ Location loc = parallelOp.getLoc();
+ SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
+ SmallVector<Value> upperBounds = parallelOp.getUpperBound();
+ SmallVector<Value> steps = parallelOp.getStep();
+
+ assert(lowerBounds.size() == upperBounds.size() &&
+ lowerBounds.size() == steps.size() &&
+ "Mismatched parallel loop bounds");
+
+ SmallVector<Value> ivs;
+ scf::LoopNest loopNest =
+ scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
+
+ SmallVector<Value> newInductionVars = llvm::map_to_vector(
+ loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
+ Block *linearizedBody = loopNest.loops.back().getBody();
+ Block *parallelBody = parallelOp.getBody();
+ rewriter.eraseOp(parallelBody->getTerminator());
+ rewriter.inlineBlockBefore(parallelBody, linearizedBody->getTerminator(),
+ newInductionVars);
+ rewriter.eraseOp(parallelOp);
+ return loopNest;
+}
+
+namespace {
+struct ParallelForToNestedFors final
+ : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
+ void runOnOperation() override {
+ Operation *parentOp = getOperation();
+ IRRewriter rewriter(parentOp->getContext());
+
+ parentOp->walk(
+ [&](scf::ParallelOp parallelOp) {
+ if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "Failed to convert scf.parallel to nested scf.for ops for:\n"
+ << parallelOp << "\n");
+ return WalkResult::advance();
+ }
+ return WalkResult::advance();
+ });
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
+ return std::make_unique<ParallelForToNestedFors>();
+}
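A minimal sketch of running the new conversion standalone through a pass manager, assuming `context` is an initialized MLIRContext and `module` a parsed ModuleOp:

  PassManager pm(&context);
  pm.addPass(mlir::createParallelForToNestedForsPass());
  if (failed(pm.run(module)))
    llvm::errs() << "scf.parallel to nested scf.for conversion failed\n";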
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 694cd85..4ea8321 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -269,10 +269,10 @@ namespace {
struct ParallelLoopFusion
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
void runOnOperation() override {
- auto &AA = getAnalysis<AliasAnalysis>();
+ auto &aa = getAnalysis<AliasAnalysis>();
auto mayAlias = [&](Value val1, Value val2) -> bool {
- return !AA.alias(val1, val2).isNo();
+ return !aa.alias(val1, val2).isNo();
};
getOperation()->walk([&](Operation *child) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 1b07b77..072bc50 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ public:
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
@@ -116,7 +116,8 @@ public:
llvm::getSingleElement(adaptor.getLowerBound()),
llvm::getSingleElement(adaptor.getUpperBound()),
llvm::getSingleElement(adaptor.getStep()),
- flattenValues(adaptor.getInitArgs()));
+ flattenValues(adaptor.getInitArgs()),
+ /*bodyBuilder=*/nullptr, op.getUnsignedCmp());
// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
@@ -126,7 +127,6 @@ public:
// Inline the type converted region from the original operation.
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
-
return newOp;
}
};
@@ -225,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(
void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
- target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes());
- });
+ target.addDynamicallyLegalOp<ForOp, IfOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
return true;
- return typeConverter.isLegal(op.getOperandTypes());
+ return typeConverter.isLegal(op.getOperands());
});
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c0e47ee..834c021 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
inits.append(newInitOperands.begin(), newInitOperands.end());
auto newLoop = scf::ForOp::create(
rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
- loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loopOp.getUnsignedCmp());
// Move the loop body to the new op.
Block *loopBody = loopOp.getBody();
@@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest(
auto newLoop = scf::ForOp::create(
rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
forLoop.getUpperBound(), forLoop.getStep(), newInits,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
+ forLoop.getUnsignedCmp());
// Merge the body of the new loop with the body of the old loops.
SmallVector<Value> sourceBlockArgs;
@@ -1914,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}
-/// Check that the loop is perfectly nested.
-/// The loops are expected to be ordered from outer most to inner most.
-/// For example:
-/// ```
-/// %0 = scf.for()
-/// %1 = scf.for()
-/// %2 = scf.for()
-/// %3 = ...
-/// yield %3
-/// yield %2
-/// yield %1
-/// ```
-/// Here loops should be [%0, %1].
-static bool
-isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
- assert(!loops.empty() && "unexpected empty loop nest");
- if (loops.size() == 1) {
- return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
- }
- for (auto [outerLoop, innerLoop] :
- llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
- auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
- auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
- if (!outerFor || !innerFor) {
- return false;
- }
- auto outerBBArgs = outerFor.getRegionIterArgs();
- auto innerIterArgs = innerFor.getInitArgs();
- if (outerBBArgs.size() != innerIterArgs.size()) {
- return false;
- }
-
- for (auto [outerBBArg, innerIterArg] :
- llvm::zip_equal(outerBBArgs, innerIterArgs)) {
- if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
- innerIterArg != outerBBArg) {
- return false;
- }
- }
-
- ValueRange outerYields =
- cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
- ValueRange innerResults = innerFor.getResults();
- if (outerYields.size() != innerResults.size()) {
- return false;
- }
- for (auto [outerYield, innerResult] :
- llvm::zip_equal(outerYields, innerResults)) {
- if (!llvm::hasSingleElement(innerResult.getUses()) ||
- outerYield != innerResult) {
- return false;
- }
- }
- }
- return true;
-}
-
/// Fetch the untiled consumer of the outermost scf.for's result which is
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
/// makes the following assumptions :
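
Stepping back from the hunks above: the recurring change in this file (and in Utils.cpp and SparseVectorization.cpp further down) is that any newly created scf.for now forwards the source loop's unsigned-comparison flag. A minimal sketch of that pattern, with a hypothetical helper name:

  #include "mlir/Dialect/SCF/IR/SCF.h"
  #include "mlir/IR/PatternMatch.h"

  // Recreate the shell of `oldLoop` with new init values while preserving its
  // bounds, step, and unsigned-comparison semantics.
  static mlir::scf::ForOp cloneLoopShell(mlir::RewriterBase &rewriter,
                                         mlir::scf::ForOp oldLoop,
                                         mlir::ValueRange newInits) {
    return mlir::scf::ForOp::create(
        rewriter, oldLoop.getLoc(), oldLoop.getLowerBound(),
        oldLoop.getUpperBound(), oldLoop.getStep(), newInits,
        /*bodyBuilder=*/nullptr, oldLoop.getUnsignedCmp());
  }
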
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5731795..684dff8 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl(
static Loops stripmineSink(scf::ForOp forOp, Value factor,
ArrayRef<scf::ForOp> targets) {
+ assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
auto originalStep = forOp.getStep();
auto iv = forOp.getInductionVar();
@@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
Loops innerLoops;
for (auto t : targets) {
+ assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
+
// Save information for splicing ops out of t when done
auto begin = t.getBody()->begin();
auto nOps = t.getBody()->getOperations().size();
@@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
+ assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
+ "incompatible signedness");
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
@@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
rewriter.setInsertionPointAfter(source);
scf::ForOp fusedLoop = scf::ForOp::create(
rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
+ source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
+ source.getUnsignedCmp());
// Map original induction variables and operands to those of the fused loop.
IRMapping mapping;
@@ -1506,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
rewriter.replaceOp(forallOp, normalizedForallOp);
return normalizedForallOp;
}
+
+bool mlir::isPerfectlyNestedForLoops(
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected empty loop nest");
+ if (loops.size() == 1)
+ return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+ if (!outerFor || !innerFor)
+ return false;
+ auto outerBBArgs = outerFor.getRegionIterArgs();
+ auto innerIterArgs = innerFor.getInitArgs();
+ if (outerBBArgs.size() != innerIterArgs.size())
+ return false;
+
+ for (auto [outerBBArg, innerIterArg] :
+ llvm::zip_equal(outerBBArgs, innerIterArgs)) {
+ if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+ innerIterArg != outerBBArg)
+ return false;
+ }
+
+ ValueRange outerYields =
+ cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+ ValueRange innerResults = innerFor.getResults();
+ if (outerYields.size() != innerResults.size())
+ return false;
+ for (auto [outerYield, innerResult] :
+ llvm::zip_equal(outerYields, innerResults)) {
+ if (!llvm::hasSingleElement(innerResult.getUses()) ||
+ outerYield != innerResult)
+ return false;
+ }
+ }
+ return true;
+}
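
With isPerfectlyNestedForLoops promoted from a static helper in TileUsingInterface.cpp to a public utility, other transforms can reuse the check. A minimal sketch, assuming the declaration lives in mlir/Dialect/SCF/Utils/Utils.h alongside this definition; the guard function itself is hypothetical:

  #include "mlir/Dialect/SCF/Utils/Utils.h"
  #include "mlir/Interfaces/LoopLikeInterface.h"
  #include "mlir/Support/LLVM.h"
  #include "llvm/ADT/SmallVector.h"

  // Reject candidate loop bands that are not a perfectly nested scf.for stack.
  static mlir::LogicalResult
  checkTileableBand(llvm::SmallVectorImpl<mlir::LoopLikeOpInterface> &band) {
    if (band.empty() || !mlir::isPerfectlyNestedForLoops(band))
      return mlir::failure();
    return mlir::success();
  }
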
diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
new file mode 100644
index 0000000..47fe4d9
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
@@ -0,0 +1,251 @@
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/InterleavedRange.h"
+
+using namespace mlir;
+using namespace mlir::spirv::AttrNames;
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ Builder &builder = parser.getBuilder();
+
+ // Parse the name as a symbol.
+ StringAttr nameAttr;
+ if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+ result.attributes))
+ return failure();
+
+ // Parse the function signature.
+ bool isVariadic = false;
+ SmallVector<OpAsmParser::Argument> entryArgs;
+ SmallVector<Type> resultTypes;
+ SmallVector<DictionaryAttr> resultAttrs;
+ if (function_interface_impl::parseFunctionSignatureWithArguments(
+ parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+ resultAttrs))
+ return failure();
+
+ SmallVector<Type> argTypes = llvm::map_to_vector(
+ entryArgs, [](const OpAsmParser::Argument &arg) { return arg.type; });
+ GraphType grType = builder.getGraphType(argTypes, resultTypes);
+ result.addAttribute(getFunctionTypeAttrName(result.name),
+ TypeAttr::get(grType));
+
+ // If additional attributes are present, parse them.
+ if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+ return failure();
+
+ // Add the attributes to the function arguments.
+ assert(resultAttrs.size() == resultTypes.size());
+ call_interface_impl::addArgAndResultAttrs(
+ builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+ getResAttrsAttrName(result.name));
+
+ // Parse the optional function body.
+ Region *body = result.addRegion();
+ OptionalParseResult parseResult =
+ parser.parseOptionalRegion(*body, entryArgs);
+ return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+ // Print graph name, signature, and control.
+ printer << " ";
+ printer.printSymbolName(getSymName());
+ GraphType grType = getFunctionType();
+ function_interface_impl::printFunctionSignature(
+ printer, *this, grType.getInputs(),
+ /*isVariadic=*/false, grType.getResults());
+ function_interface_impl::printFunctionAttributes(printer, *this,
+ {getFunctionTypeAttrName(),
+ getArgAttrsAttrName(),
+ getResAttrsAttrName()});
+
+ // Print the body.
+ Region &body = this->getBody();
+ if (!body.empty()) {
+ printer << ' ';
+ printer.printRegion(body, /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+ }
+}
+
+LogicalResult spirv::GraphARMOp::verifyType() {
+ if (getFunctionType().getNumResults() < 1)
+ return emitOpError("there should be at least one result");
+ return success();
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+ for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
+ if (!isa<spirv::TensorArmType>(graphArgType)) {
+ return emitOpError("type of argument #")
+ << index << " must be a TensorArmType, but got " << graphArgType;
+ }
+ }
+ for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
+ if (!isa<spirv::TensorArmType>(graphResType)) {
+ return emitOpError("type of result #")
+ << index << " must be a TensorArmType, but got " << graphResType;
+ }
+ }
+
+ if (!isExternal()) {
+ Block &entryBlock = front();
+
+ unsigned numArguments = this->getNumArguments();
+ if (entryBlock.getNumArguments() != numArguments)
+ return emitOpError("entry block must have ")
+ << numArguments << " arguments to match graph signature";
+
+ for (auto [index, grArgType, blockArgType] :
+ llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+ if (blockArgType != grArgType) {
+ return emitOpError("type of entry block argument #")
+ << index << '(' << blockArgType
+ << ") must match the type of the corresponding argument in "
+ << "graph signature(" << grArgType << ')';
+ }
+ }
+ }
+
+ GraphType grType = getFunctionType();
+ auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
+ if (grType.getNumResults() != op.getNumOperands())
+ return op.emitOpError("is returning ")
+ << op.getNumOperands()
+ << " value(s) but enclosing spirv.ARM.Graph requires "
+ << grType.getNumResults() << " result(s)";
+
+ ValueTypeRange<OperandRange> graphOutputOperandTypes =
+ op.getValue().getType();
+ for (auto [index, type] : llvm::enumerate(graphOutputOperandTypes)) {
+ if (type != grType.getResult(index))
+ return op.emitError("type of return operand ")
+ << index << " (" << type << ") doesn't match graph result type ("
+ << grType.getResult(index) << ")";
+ }
+ return WalkResult::advance();
+ });
+
+ return failure(walkResult.wasInterrupted());
+}
+
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+ StringRef name, GraphType type,
+ ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+ state.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder.getStringAttr(name));
+ state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.attributes.append(attrs);
+ state.addAttribute(getEntryPointAttrName(state.name),
+ builder.getBoolAttr(entryPoint));
+ state.addRegion();
+}
+
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+ return getFunctionType().getInputs();
+}
+
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+ return getFunctionType().getResults();
+}
+
+Region *spirv::GraphARMOp::getCallableRegion() {
+ return isExternal() ? nullptr : &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphOutputsARM
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+ auto graph = cast<GraphARMOp>((*this)->getParentOp());
+
+ // The operand number and types must match the graph signature.
+ const ArrayRef<Type> &results = graph.getFunctionType().getResults();
+ if (getNumOperands() != results.size())
+ return emitOpError("has ")
+ << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
+ << graph.getName() << ") returns " << results.size();
+
+ for (auto [index, result] : llvm::enumerate(results))
+ if (getOperand(index).getType() != result)
+ return emitError() << "type of return operand " << index << " ("
+ << getOperand(index).getType()
+ << ") doesn't match spirv.ARM.Graph result type ("
+ << result << ")"
+ << " in graph @" << graph.getName();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+ OperationState &state,
+ spirv::GraphARMOp graph,
+ ArrayRef<Attribute> interfaceVars) {
+ build(builder, state, SymbolRefAttr::get(graph),
+ builder.getArrayAttr(interfaceVars));
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ FlatSymbolRefAttr fn;
+ if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
+ return failure();
+
+ SmallVector<Attribute, 4> interfaceVars;
+ if (!parser.parseOptionalComma()) {
+ // Parse the interface variables.
+ if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+ // The name of the interface variable attribute is not important.
+ FlatSymbolRefAttr var;
+ NamedAttrList attrs;
+ if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+ return failure();
+ interfaceVars.push_back(var);
+ return success();
+ }))
+ return failure();
+ }
+ result.addAttribute("interface",
+ parser.getBuilder().getArrayAttr(interfaceVars));
+ return success();
+}
+
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+ printer << " ";
+ printer.printSymbolName(getFn());
+ ArrayRef<Attribute> interfaceVars = getInterface().getValue();
+ if (!interfaceVars.empty()) {
+ printer << ", " << llvm::interleaved(interfaceVars);
+ }
+}
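
For reference, a hedged sketch of inspecting the new ops programmatically with the accessors defined above (the helper name is hypothetical):

  #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
  #include "llvm/Support/raw_ostream.h"

  // Walk an op and report every spirv.ARM.Graph symbol with its signature.
  static void dumpGraphSignatures(mlir::Operation *root) {
    root->walk([](mlir::spirv::GraphARMOp graph) {
      llvm::outs() << graph.getSymName() << ": "
                   << graph.getArgumentTypes().size() << " argument(s), "
                   << graph.getResultTypes().size() << " result(s), "
                   << (graph.isExternal() ? "external" : "with body") << "\n";
    });
  }
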
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b9aa7b7..60d705d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRVDialect
+ ArmGraphOps.cpp
AtomicOps.cpp
CastOps.cpp
ControlFlowOps.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe16..2f3a28f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
return isNestedInFunctionOpInterface(op->getParentOp());
}
+/// Returns true if the given op is a GraphARM op or nested in a
+/// GraphARM op without a module-like op in the middle.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+ if (!op)
+ return false;
+ if (op->hasTrait<OpTrait::SymbolTable>())
+ return false;
+ if (isa<spirv::GraphARMOp>(op))
+ return true;
+ return isNestedInGraphARMOpInterface(op->getParentOp());
+}
+
/// Returns true if the given op is a module-like op that maintains a symbol
/// table.
static bool isDirectInModuleLikeOp(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index ddb3426..369b953 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1322,7 +1322,7 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
}
TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
- : shape(std::move(shape)), elementType(std::move(elementType)) {}
+ : shape(shape), elementType(elementType) {}
ArrayRef<int64_t> shape;
Type elementType;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 8f4c4cc..49f4ce8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -608,6 +608,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
+static spirv::Dim convertRank(int64_t rank) {
+ switch (rank) {
+ case 1:
+ return spirv::Dim::Dim1D;
+ case 2:
+ return spirv::Dim::Dim2D;
+ case 3:
+ return spirv::Dim::Dim3D;
+ default:
+ llvm_unreachable("Invalid memref rank!");
+ }
+}
+
+static spirv::ImageFormat getImageFormat(Type elementType) {
+ return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
+ .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
+ .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
+ .Case<IntegerType>([](IntegerType intType) {
+ auto const isSigned = intType.isSigned() || intType.isSignless();
+#define BIT_WIDTH_CASE(BIT_WIDTH) \
+ case BIT_WIDTH: \
+ return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
+ : spirv::ImageFormat::R##BIT_WIDTH##ui
+
+ switch (intType.getWidth()) {
+ BIT_WIDTH_CASE(16);
+ BIT_WIDTH_CASE(32);
+ default:
+ llvm_unreachable("Unhandled integer type!");
+ }
+ })
+ .Default([](Type) {
+ llvm_unreachable("Unhandled element type!");
+ // We need to return something here to satisfy the type switch.
+ return spirv::ImageFormat::R32f;
+ });
+#undef BIT_WIDTH_CASE
+}
+
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options,
MemRefType type) {
@@ -623,6 +662,41 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
spirv::StorageClass storageClass = attr.getValue();
+ // Images are a special case since they are opaque types whose elements
+ // may be accessed via image-specific ops or directly through a texture
+ // pointer.
+ if (storageClass == spirv::StorageClass::Image) {
+ const int64_t rank = type.getRank();
+ if (rank < 1 || rank > 3) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: cannot lower memref of rank " << rank
+ << " to a SPIR-V Image\n");
+ return nullptr;
+ }
+
+ // Note that we currently only support lowering to single-element texels,
+ // e.g. R32f.
+ auto elementType = type.getElementType();
+ if (!isa<spirv::ScalarType>(elementType)) {
+ LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
+ << elementType << " to a SPIR-V Image\n");
+ return nullptr;
+ }
+
+ // Currently every memref in the image storage class is converted to a
+ // sampled image, so we can hardcode the NeedSampler field. Future work
+ // will generalize this to support regular non-sampled images.
+ auto spvImageType = spirv::ImageType::get(
+ elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
+ spirv::ImageArrayedInfo::NonArrayed,
+ spirv::ImageSamplingInfo::SingleSampled,
+ spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
+ auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
+ auto imagePtrType = spirv::PointerType::get(
+ spvSampledImageType, spirv::StorageClass::UniformConstant);
+ return imagePtrType;
+ }
+
if (isa<IntegerType>(type.getElementType())) {
if (type.getElementTypeBitWidth() == 1)
return convertBoolMemrefType(targetEnv, options, type, storageClass);
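
To make the new Image storage-class path above concrete: a rank-2 f32 memref in the Image storage class becomes a pointer to a sampled image. A hedged sketch that builds the same type by hand (the helper name and its builder parameter are hypothetical):

  #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
  #include "mlir/IR/Builders.h"

  // Expected conversion result for memref<?x?xf32, #spirv.storage_class<Image>>:
  // !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, ...>>, UniformConstant>
  static mlir::Type buildSampledImagePointer(mlir::OpBuilder &builder) {
    auto imageTy = mlir::spirv::ImageType::get(
        builder.getF32Type(), mlir::spirv::Dim::Dim2D,
        mlir::spirv::ImageDepthInfo::DepthUnknown,
        mlir::spirv::ImageArrayedInfo::NonArrayed,
        mlir::spirv::ImageSamplingInfo::SingleSampled,
        mlir::spirv::ImageSamplerUseInfo::NeedSampler,
        mlir::spirv::ImageFormat::R32f);
    auto sampledTy = mlir::spirv::SampledImageType::get(imageTy);
    return mlir::spirv::PointerType::get(
        sampledTy, mlir::spirv::StorageClass::UniformConstant);
  }
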
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index a53d0a7..670eabf 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -95,6 +95,13 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
return success();
}
+static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+ SetVector<spirv::Capability> tmp;
+ for (spirv::Capability cap : caps)
+ tmp.insert_range(getRecursiveImpliedCapabilities(cap));
+ caps.insert_range(std::move(tmp));
+}
+
void UpdateVCEPass::runOnOperation() {
spirv::ModuleOp module = getOperation();
@@ -151,6 +158,12 @@ void UpdateVCEPass::runOnOperation() {
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
valueTypes.push_back(globalVar.getType());
+ // If the op is FunctionLike, process its argument and result types too.
+ if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
+ llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
+ llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
+ }
+
// Requirements from values' types
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
@@ -174,6 +187,8 @@ void UpdateVCEPass::runOnOperation() {
if (walkResult.wasInterrupted())
return signalPassFailure();
+ addAllImpliedCapabilities(deducedCapabilities);
+
// Update min version requirement for capabilities after deducing them.
for (spirv::Capability cap : deducedCapabilities) {
if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
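
The new addAllImpliedCapabilities step closes the deduced capability set under implication, so that, for example, deducing Shader also records the capabilities it implies (such as Matrix). A small sketch of the same idea using the query the pass calls; the wrapper function and the assumption that getRecursiveImpliedCapabilities is declared in SPIRVEnums.h are mine:

  #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
  #include "llvm/ADT/SetVector.h"

  // Close `caps` under capability implication, mirroring the pass logic above.
  static void
  closeUnderImplication(llvm::SetVector<mlir::spirv::Capability> &caps) {
    llvm::SetVector<mlir::spirv::Capability> implied;
    for (mlir::spirv::Capability cap : caps)
      implied.insert_range(mlir::spirv::getRecursiveImpliedCapabilities(cap));
    caps.insert_range(implied);
  }
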
diff --git a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
index d4e7618..7a05dfe 100644
--- a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
@@ -513,8 +513,9 @@ LogicalResult shard::detail::defaultAddShardingAnnotations(
}
#ifndef NDEBUG
-static bool isValueCompatibleWithFullReplicationSharding(Value value,
- Sharding sharding) {
+static bool
+isValueCompatibleWithFullReplicationSharding(Value value,
+ const Sharding &sharding) {
if (isa<RankedTensorType>(value.getType())) {
return isFullReplication(sharding);
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 3e3d476..5dc61a2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -477,10 +477,10 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
return targetShard;
}
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
- Sharding sourceSharding, Sharding targetSharding,
- TypedValue<ShapedType> sourceUnshardedValue,
- TypedValue<ShapedType> sourceShard) {
+static TypedValue<ShapedType>
+reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding,
+ Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
@@ -535,7 +535,7 @@ using UnshardedToShardedValueMap = DenseMap<Value, Value>;
// Get the types of block arguments for a partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
-SmallVector<Type>
+static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
SymbolTableCollection &symbolTableCollection) {
SmallVector<Type> res;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 56b435c..9694a40 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -231,7 +231,9 @@ ParseResult DimLvlMapParser::parseLvlSpecList() {
const auto loc = parser.getCurrentLocation();
const auto res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::Paren,
- [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
+ [this, requireLvlVarBinding]() -> ParseResult {
+ return parseLvlSpec(requireLvlVarBinding);
+ },
" in level-specifier list");
FAILURE_IF_FAILED(res)
const auto specLvlRank = lvlSpecs.size();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 9e2e6ab..a1711a6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -156,13 +156,14 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
return pair1 <= pair2 ? sm1 : sm2;
}
-bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
+static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id,
+ StringRef name) {
const auto &var = env.access(id);
return (var.getName() == name && var.getID() == id);
}
-bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
- VarKind vk) {
+static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id,
+ llvm::SMLoc loc, VarKind vk) {
const auto &var = env.access(id);
return var.getKind() == vk;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 3b97786..dabbea1 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createLowerAffinePass());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
@@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createConvertComplexToLibm());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertComplexToLLVMPass());
- pm.addPass(
- createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertFuncToLLVMPass());
- pm.addPass(createArithToLLVMConversionPass());
- pm.addPass(createConvertControlFlowToLLVMPass());
// Finalize GPU code generation.
if (gpuCodegen) {
@@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
}
- // Convert poison values.
- pm.addPass(createUBToLLVMConversionPass());
+ // Convert to LLVM.
+ pm.addPass(createConvertToLLVMPass());
// Ensure all casts are realized.
pm.addPass(createReconcileUnrealizedCastsPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 3b4140e..ae7eef2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1219,8 +1219,9 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
-LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
- uint64_t ny, PatternRewriter &rewriter) {
+static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys,
+ AffineMap xPerm, uint64_t ny,
+ PatternRewriter &rewriter) {
Location loc = op.getLoc();
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3..0e88d31d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ public:
{tensor, lvlCoords, values, filled, added, count},
EmitCInterface::On);
Operation *parent = getTop(op);
+ rewriter.setInsertionPointAfter(parent);
rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
- rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 4464450..febec6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
VectorType vtp = vectorType(vl, init.getType());
Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
forOp.getRegionIterArg(0), init, vtp);
- forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
- forOp.getUpperBound(), step, vinit);
+ forOpNew =
+ scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
+ forOp.getUpperBound(), step, vinit,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
forOpNew->setAttr(
LoopEmitter::getLoopEmitterLoopAttrName(),
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
@@ -605,8 +607,8 @@ public:
ForOpRewriter(MLIRContext *context, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32)
- : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
- enableSIMDIndex32} {}
+ : OpRewritePattern(context),
+ vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d4b112..68584ec 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3200,20 +3200,6 @@ void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "padded");
}
-// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
-// supports optional types.
-void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
- Type typeToInfer, Type typeToInferFrom) {}
-
-ParseResult
-parseInferType(OpAsmParser &parser,
- std::optional<OpAsmParser::UnresolvedOperand> optOperand,
- Type &typeToInfer, Type typeToInferFrom) {
- if (optOperand)
- typeToInfer = typeToInferFrom;
- return success();
-}
-
LogicalResult PadOp::verify() {
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
@@ -4059,7 +4045,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
-bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
// 1. InsertSliceOp has its own logic about folding tensor.cast ops.
// 2. Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1..dfce835 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
PatternRewriter &rewriter) const override {
auto expandShapeOp =
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "tensor.extract_slice source not produced by expand_shape");
+ }
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
- if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
- rewriter)
- .failed())
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
+ offsets, sizes, strides)))
return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
- // referring to the state before applying the pattern are named with the
- // prefix "expanded", and ones referring to the state after applying the
- // pattern are named with the prefix "collapsed".
- SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> expandedShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- // Helper variables and function for accumulating the size values.
- Location loc = expandShapeOp->getLoc();
- AffineExpr d0, d1, d2;
- bindDims(rewriter.getContext(), d0, d1, d2);
- // Multiply two integers.
- auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
- auto mulMap = AffineMap::get(2, 0, {d0 * d1});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2});
- };
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank of
- // ReassociationIndices.size(). In the loop a single offset, size, and
- // stride value is computed per reassociation group.
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
- collapsedStrides;
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- // collapsedSize will hold the size of the single dim that represents the
- // reassociation group in the non expanded tensor.
- OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
- // The reassocGroupSizes and reassocGroupOffsets are used to create an
- // affine.linearize_index op to linearize the single offset value required
- // for this reassociation group.
- SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
- for (long expandedDim : indices) {
- // reassocGroupSizes and reassocGroupOffsets can be obtained directly
- // from the expanded state, but the collapsed size requires calculation
- // as it did not previously exist.
- reassocGroupSizes.push_back(expandedShape[expandedDim]);
- reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
- collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
- }
-
- SmallVector<Value> offsetVals =
- llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
- });
- OpFoldResult collapsedOffset =
- affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
- reassocGroupSizes,
- /*disjoint=*/true)
- .getResult();
- collapsedOffsets.push_back(collapsedOffset);
- collapsedSizes.push_back(collapsedSize);
-
- // Only unit stride is supported.
- collapsedStrides.push_back(rewriter.getIndexAttr(1));
- }
-
// The shape of the result can be obtained from the sizes passed in.
- SmallVector<Value> dynDims;
- SmallVector<int64_t> shape;
- dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
- RankedTensorType resultType = RankedTensorType::get(
- shape, expandShapeOp.getResultType().getElementType());
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ RankedTensorType resultType = sliceOp.getResultType();
// Create a new ExtractSliceOp and ExpandShapeOp.
+ Location loc = sliceOp.getLoc();
Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
- collapsedStrides);
+ rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
sliceOp, resultType, newSliceOp,
expandShapeOp.getReassociationIndices(), expandedSizes);
return success();
}
-
- // Helper function to check if all the required conditions for the
- // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
- // met.
- LogicalResult
- checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
- tensor::ExpandShapeOp expandShapeOp,
- PatternRewriter &rewriter) const {
-
- if (!expandShapeOp) {
- return rewriter.notifyMatchFailure(
- sliceOp, "tensor.extract_slice source not produced by expand_shape");
- }
-
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
-
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- sizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
-
- SmallVector<OpFoldResult> outputShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
- isZeroOffsetAndFullSize =
- [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isZeroInteger(offset))
- return false;
- FailureOr<bool> maybeEqual =
- ValueBoundsConstraintSet::areEqual(sliceSize, size);
- return llvm::succeeded(maybeEqual) && maybeEqual.value();
- };
-
- // Check that the slice is contiguous within each reassociation group.
- // The slice is contiguous only if after the first dimension where a non
- // unit slice is taken, the slice size on all subsequent dimensions of the
- // group is equal to the entire size of the dimension.
- // Examples of contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
- // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
- // Examples of non contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
- // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- int64_t i = 0;
- int64_t e = indices.size();
- // Find the first expanded dim after the first dim with non-unit extracted
- // size.
- for (; i < e; ++i) {
- if (!isOneInteger(sizes[indices[i]])) {
- // +1 to skip the first non-unit size dim.
- i++;
- break;
- }
- }
-
- // Verify that all subsequent dimensions extract the full size of the
- // source tensor.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
- outputShape[expandedDim])) {
- return rewriter.notifyMatchFailure(
- sliceOp, "Not a contiguous slice of the expanded tensor.");
- }
- }
- }
-
- return success();
- }
};
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,281 @@ struct BubbleUpCollapseShapeThroughExtractSlice
"tensor.extract_slice source not produced by tensor.collapse_shape");
}
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getExpandedExtractSliceInfo(
+ rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
+ collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
+ return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.collapse_shape, so variables (i.e. inputs for
- // ExtractSliceOp) referring to the state before applying the pattern are
- // named with the prefix "collapsed", and ones referring to the state after
- // applying the pattern are named with the prefix "expanded".
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- collapsedSizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
+ Value newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+ sizes, strides);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ sliceOp, sliceOp.getResultType(), newSliceOp,
+ collapseShapeOp.getReassociationIndices());
- ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
- SmallVector<ReassociationIndices, 4> reassociationIndices =
- collapseShapeOp.getReassociationIndices();
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank
- // equal to the rank of the src of the collapse_shape. In each iteration of
- // the loop, the offsets and sizes will be computed per reassociation group.
- SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
- SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
- rewriter.getIndexAttr(1));
-
- for (auto [collapsedSize, collapsedOffset, reassocIndices] :
- llvm::zip_equal(collapsedSizes, collapsedOffsets,
- collapseShapeOp.getReassociationIndices())) {
- // CASE #1 - size and/or offset are dynamic.
- // In this case, the slice can be represented as a contiguous slice only
- // if there is a single dimension in the reassociation group that has a
- // size not equal to 1.
- if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
- int nonUnitSizeCount = 0;
- for (int64_t expandedShapeIdx : reassocIndices) {
- if (srcShape[expandedShapeIdx] != 1) {
- nonUnitSizeCount++;
- expandedSizes.push_back(collapsedSize);
- expandedOffsets.push_back(collapsedOffset);
- continue;
- }
-
- expandedSizes.push_back(rewriter.getIndexAttr(1));
- expandedOffsets.push_back(rewriter.getIndexAttr(0));
- }
+ return success();
+ }
+};
- if (nonUnitSizeCount != 1) {
- return rewriter.notifyMatchFailure(
- sliceOp,
- "unsupported: slice cannot be verified to be contiguous");
- }
- continue;
- }
+} // namespace
- // CASE #2 = size and offset are static.
- // Verify that the slice can be represented as a contiguous slice of the
- // src of the collapse_shape.
- // Checking this is done on order of most internal dimensions first,
- // so traversal is done in reverse order of the reassociation group.
- // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
- // ...,An] then we first find the size and offset for n...k+1 then for k
- // and then for k-1...0.
-
- // currentCollapsedsize and currentCollapsedOffset are initialized with
- // the original collapsed size and offset and divided by the expanded
- // shape size in each dimension as we go along the reassociation group.
- // In essence we are spreading the original collapsed size and offset over
- // the various expanded slice dimensions.
- // The variables are used both to check the validity of the slice and to
- // compute the expanded sizes and offsets.
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
- int64_t currentCollapsedOffset =
- getConstantIntValue(collapsedOffset).value();
-
- SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
- ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
- reassocIndices.rend());
- int64_t idx = 0;
- int64_t reassocGroupSize = reassocIndices.size();
-
- // First handle the trailing dimensions where the slice size should be
- // equal to the tensor shape and the offset should be 0 (n...k+1).
- for (; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
- if (currentCollapsedsize < expandedShapeSize)
- break;
-
- // We need to make sure that the slice size can be set to the shape size
- // and the offset to 0.
- if ((currentCollapsedsize % expandedShapeSize) != 0 ||
- (currentCollapsedOffset % expandedShapeSize) != 0) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: cannot be extracted as a contiguous slice "
- "of the src of the collapse_shape");
- }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+ OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+ return failure();
+ }
- currentCollapsedsize /= expandedShapeSize;
- currentCollapsedOffset /= expandedShapeSize;
+ auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+ OpFoldResult sliceSize, int64_t inputDim) {
+ if (!isZeroInteger(offset))
+ return false;
+ ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // Check that the slice is contiguous within each reassociation group.
+ // The slice is contiguous only if, after the first dimension where a
+ // non-unit slice is taken, the slice size on all subsequent dimensions of
+ // the group is equal to the entire size of the dimension.
+ // Examples of contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+ // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+ // Examples of non-contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+ // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+ for (const ReassociationIndices &indices : reassociation) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isOneInteger(sizes[indices[i]])) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
}
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ expandedDim)) {
+ return failure();
+ }
+ }
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+ // referring to the state before applying the pattern are named with the
+ // prefix "expanded", and ones referring to the state after applying the
+ // pattern are named with the prefix "collapsed".
+ Location loc = sliceOp.getLoc();
+ SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> expandedShape =
+ getMixedSizes(b, loc, sliceOp.getSource());
+
+ // Helper variables and function for accumulating the size values.
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
+ };
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank of
+ // ReassociationIndices.size(). In the loop a single offset, size, and
+ // stride value is computed per reassociation group.
+ for (const ReassociationIndices &indices : reassociation) {
+ // collapsedSize will hold the size of the single dim that represents the
+ // reassociation group in the non expanded tensor.
+ OpFoldResult collapsedSize = b.getIndexAttr(1);
+ // The reassocGroupSizes and reassocGroupOffsets are used to create an
+ // affine.linearize_index op to linearize the single offset value required
+ // for this reassociation group.
+ SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+ for (long expandedDim : indices) {
+ // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+ // from the expanded state, but the collapsed size requires calculation
+ // as it did not previously exist.
+ reassocGroupSizes.push_back(expandedShape[expandedDim]);
+ reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+ collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+ }
+
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ });
+ OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
+ b, loc, offsetVals, reassocGroupSizes,
+ /*disjoint=*/true)
+ .getResult();
+ collapsedOffsets.push_back(collapsedOffset);
+ collapsedSizes.push_back(collapsedSize);
+
+ // Only unit stride is supported.
+ collapsedStrides.push_back(b.getIndexAttr(1));
+ }
+ return success();
+}
+
+LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
+ OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
+ // ExtractSliceOp) referring to the state before applying the pattern are
+ // named with the prefix "collapsed", and ones referring to the state after
+ // applying the pattern are named with the prefix "expanded".
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+ collapsedSizes.size()) {
+ return failure();
+ }
- // Now handle the first dim where slicing occurs on (k).
- if (idx < reassocGroupSize) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- // We need to make sure that the slice size in this dim + offset will
- // not exceed the shape size.
- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: slice cannot be extracted as a contiguous "
- "slice of the src of the collapse_shape");
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank
+ // equal to the rank of the src of the collapse_shape. In each iteration of
+ // the loop, the offsets and sizes will be computed per reassociation group.
+ expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+ llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
+ // CASE #1 - size and/or offset are dynamic.
+ // In this case, the slice can be represented as a contiguous slice only
+ // if there is a single dimension in the reassociation group that has a
+ // size not equal to 1.
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ int nonUnitSizeCount = 0;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (expandedShape[expandedShapeIdx] != 1) {
+ nonUnitSizeCount++;
+ expandedSizes.push_back(collapsedSize);
+ expandedOffsets.push_back(collapsedOffset);
+ continue;
}
- groupExpandedSizes.push_back(
- rewriter.getIndexAttr(currentCollapsedsize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+ expandedSizes.push_back(b.getIndexAttr(1));
+ expandedOffsets.push_back(b.getIndexAttr(0));
+ }
- currentCollapsedOffset /= expandedShapeSize;
+ if (nonUnitSizeCount != 1) {
+ return failure();
}
+ continue;
+ }
- // Now handle the leading dimensions where the slice size is equal to 1
- // (k-1...0).
- // The size for these dimensions must be 1 because of how we constructed
- // the slice size of the expanded shape. We spread the original collapsed
- // size over the expanded shape sizes until we reached dimension k where
- // the remaining size was smaller than the expanded shape size, and spread
- // the remaining size on it. So, now we are left with only 1s.
- for (idx++; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- currentCollapsedOffset /= expandedShapeSize;
+ // CASE #2 - size and offset are static.
+ // Verify that the slice can be represented as a contiguous slice of the
+ // src of the collapse_shape.
+ // Checking this is done on order of most internal dimensions first,
+ // so traversal is done in reverse order of the reassociation group.
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+ // ...,An] then we first find the size and offset for n...k+1 then for k
+ // and then for k-1...0.
+
+ // currentCollapsedsize and currentCollapsedOffset are initialized with
+ // the original collapsed size and offset and divided by the expanded
+ // shape size in each dimension as we go along the reassociation group.
+ // In essence we are spreading the original collapsed size and offset over
+ // the various expanded slice dimensions.
+ // The variables are used both to check the validity of the slice and to
+ // compute the expanded sizes and offsets.
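+    // For example (illustrative): for a group with expanded sizes [4, 5, 6],
+    // a static collapsed size 30 at offset 60 is spread (walking the group
+    // in reverse) into expanded sizes [1, 5, 6] and offsets [2, 0, 0], i.e.
+    // the same linear range [60, 90).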
+ int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+ int64_t currentCollapsedOffset =
+ getConstantIntValue(collapsedOffset).value();
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+ if (currentCollapsedsize < expandedShapeSize)
+ break;
+
+ // We need to make sure that the slice size can be set to the shape size
+ // and the offset to 0.
+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+ (currentCollapsedOffset % expandedShapeSize) != 0) {
+ return failure();
}
- expandedSizes.append(groupExpandedSizes.rbegin(),
- groupExpandedSizes.rend());
- expandedOffsets.append(groupExpandedOffsets.rbegin(),
- groupExpandedOffsets.rend());
+ groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(0));
+
+ currentCollapsedsize /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
- Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
- expandedOffsets, expandedSizes, expandedStrides);
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- sliceOp, sliceOp.getResultType(), newSliceOp,
- collapseShapeOp.getReassociationIndices());
+ // Now handle the first dim where slicing occurs on (k).
+ if (idx < reassocGroupSize) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ // We need to make sure that the slice size in this dim + offset will
+ // not exceed the shape size.
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
+ return failure();
+ }
+ groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
- return success();
+ // Now handle the leading dimensions where the slice size is equal to 1
+ // (k-1...0).
+ // The size for these dimensions must be 1 because of how we constructed
+ // the slice size of the expanded shape. We spread the original collapsed
+    // size over the expanded shape sizes until we reached dimension k, where
+    // the remaining size was smaller than the expanded shape size, and placed
+    // the remaining size there. So we are now left with only 1s.
+ for (idx++; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ groupExpandedSizes.push_back(b.getIndexAttr(1));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
+ expandedSizes.append(groupExpandedSizes.rbegin(),
+ groupExpandedSizes.rend());
+ expandedOffsets.append(groupExpandedOffsets.rbegin(),
+ groupExpandedOffsets.rend());
}
-};
-
-} // namespace
+ return success();
+}
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba388..8d63646 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -122,8 +122,9 @@ struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
const APFloat lowestVal =
APFloat::getLargest(padConstVal.getSemantics(), true);
return padConstVal == lowestVal;
- } else if (auto padConstIntAttr =
- mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+ }
+ if (auto padConstIntAttr =
+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
const APInt padConstVal = *padConstIntAttr.begin();
const unsigned int bitWidth = padConstVal.getBitWidth();
const APInt lowestVal =
@@ -555,7 +556,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
- if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ if (opNanMode == NanPropagationMode::IGNORE &&
+ clampNanMode == NanPropagationMode::PROPAGATE)
return failure();
auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +638,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}
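+    // If the modes differ, the earlier bail-out leaves only the combination
+    // opNanMode == PROPAGATE with clampNanMode == IGNORE, so the merged
+    // clamp must use IGNORE.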
+ auto newMode = (opNanMode != clampNanMode)
+ ? tosa::NanPropagationMode::IGNORE
+ : opNanMode;
+
+ auto newModeAttr =
+ NanPropagationModeAttr::get(rewriter.getContext(), newMode);
+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
- rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
- : opNanMode));
+ newModeAttr);
return success();
}
};
@@ -1120,13 +1128,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
if (rhsTy == resultTy) {
- if (isSplatZero(resultETy, lhsAttr))
+ if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+      // Constant values can only be resized when the result type is static.
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
- if (isSplatZero(resultETy, rhsAttr))
+ if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3cafb19..bd7aee5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -270,6 +270,244 @@ void mlir::tosa::printVariableOpTypeOrInitialValue(
}
}
+namespace {
+
+// Parse a single attribute entry, with special handling for TOSA enum
+// attributes.
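+// For example (illustrative), this lets `rounding_mode = DOUBLE_ROUND` parse
+// as a bare enum keyword, in addition to the fully qualified attribute form.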
+template <typename EnumType>
+ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
+ NamedAttrList &outAttrs) {
+ llvm::StringRef name;
+ if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
+ return failure();
+
+ // special handling: rounding_mode accepts a *bare* RoundingMode enum
+ // keyword.
+ llvm::StringRef kw;
+ if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
+ if (name == "rounding_mode" &&
+ succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeRoundingMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid rounding_mode value: " << kw;
+ auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+ // special handling: mode accepts a *bare* ResizeMode enum keyword.
+ if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
+ if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeResizeMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid resize mode value: " << kw;
+ auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+ // special handling: nan_mode accepts a *bare* NanPropagationMode enum
+ // keyword.
+ if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
+ if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeNanPropagationMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid nan_mode value: " << kw;
+ auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
+  // Default path: parse any normal attribute literal, including a fully
+  // qualified enum attribute.
+ Attribute attr;
+ return parser.parseAttribute(attr, name, outAttrs);
+}
+
+template <typename EnumType>
+ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
+  // Parse the operands.
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> operands;
+ if (parser.parseCommaSeparatedList(
+ [&]() { return parser.parseOperand(operands.emplace_back()); }))
+ return failure();
+
+  // Parse { attr-dict } with special handling for bare enum keywords.
+ NamedAttrList attrs;
+ if (succeeded(parser.parseOptionalLBrace()) &&
+ failed(parser.parseOptionalRBrace())) {
+ do {
+ if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
+ return failure();
+ } while (succeeded(parser.parseOptionalComma()));
+ if (parser.parseRBrace())
+ return failure();
+ }
+
+ FunctionType fnTy;
+ if (parser.parseColonType(fnTy))
+ return failure();
+
+ // Resolve operands and types
+ if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+
+ result.addTypes(fnTy.getResult(0));
+ result.addAttributes(attrs);
+
+ return success();
+}
+
+void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
+ parser << namedAttr.getName().strref() << " = ";
+ auto attr = namedAttr.getValue();
+ if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
+ parser << roundingModeAttr.getValue();
+ } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
+ parser << resizeModeAttr.getValue();
+ } else if (auto nanPropagationModeAttr =
+ dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
+ parser << nanPropagationModeAttr.getValue();
+ } else {
+ parser.printAttribute(attr);
+ }
+}
+
+// Print with special handling for the default-valued NanPropagationMode
+// attribute.
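+// For example (illustrative), a nan_mode = PROPAGATE entry is elided from
+// the printed attribute dictionary, since PROPAGATE is the default.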
+void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
+ parser << " ";
+ parser.printOperands(op->getOperands());
+
+ NamedAttrList toPrint(op->getAttrs());
+  // Elide the default-valued NanPropagationMode attribute.
+ const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
+ for (auto attr : op->getAttrs()) {
+ if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
+ if (nanAttr.getValue() == kDefaultNanValue) {
+ // elide from toPrint
+ toPrint.erase(attr.getName());
+ break;
+ }
+ }
+ }
+
+ if (!toPrint.empty()) {
+ parser << " {";
+ llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
+ printNamedAttr(parser, namedAttr);
+ });
+ parser << "}";
+ }
+
+ parser << " : ";
+ parser.printFunctionalType(op);
+}
+
+// Print with special handling for the RoundingMode and ResizeMode enums.
+void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
+ parser << " ";
+ parser.printOperands(op->getOperands());
+
+ if (!op->getAttrs().empty()) {
+ parser << " {";
+ llvm::interleaveComma(op->getAttrs(), parser,
+ [&](const NamedAttribute namedAttr) {
+ printNamedAttr(parser, namedAttr);
+ });
+ parser << "}";
+ }
+
+ parser << " : ";
+ parser.printFunctionalType(op);
+}
+
+} // namespace
+
+ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
+}
+
+void RescaleOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
+}
+
+void ApplyScaleOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
+}
+
+void ResizeOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ArgMaxOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ClampOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaximumOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MinimumOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ReduceMaxOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ReduceMinOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 5590927..8143b27 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -658,10 +658,10 @@ void TosaReduceTransposes::runOnOperation() {
// (like the TransposeOp we insert for ReshapeOp),
// but in this case, that is specialized enough and overlaps
// with another direct-use TransposeOp case we need to cover anyway.
- transposeInfo.push_back({transposeOp, dependentOps});
+ transposeInfo.emplace_back(transposeOp, dependentOps);
// This is for the final replacement across all transposes.
- totalTransposeOrder.push({transposeOp, perms});
+ totalTransposeOrder.emplace(transposeOp, perms);
});
// We want to do a full fan-in analysis on a perms-level,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534..790bbf7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,14 +508,15 @@ private:
bool attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
- if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+ if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
return false;
- } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
- !targetEnv.allows(Extension::inexactround)) {
+ }
+ if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
+ !targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
@@ -1122,7 +1123,7 @@ bool checkErrorIfRescale(Operation *op) {
}
// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
- if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+ if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) {
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
return false;
}
@@ -1307,7 +1308,8 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
- } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+ }
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 9266a63..48df1a0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -37,16 +37,13 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
-
#define DEBUG_TYPE_MATCHER "transform-matcher"
-#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
-#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
using namespace mlir;
@@ -182,8 +179,7 @@ transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
if (result.isSilenceableFailure()) {
- LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
- << "\n");
+ LDBG() << "alternative failed: " << result.getMessage();
failed = true;
break;
}
@@ -1155,12 +1151,10 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
std::optional<DiagnosedSilenceableFailure> maybeFailure;
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
- DEBUG_MATCHER({
- DBGS_MATCHER() << "matching ";
- op->print(llvm::dbgs(),
- OpPrintingFlags().assumeVerified().skipRegions());
- llvm::dbgs() << " @" << op << "\n";
- });
+ LDBG(1, DEBUG_TYPE_MATCHER)
+ << "matching "
+ << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
+ << " @" << op;
// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
@@ -1172,8 +1166,8 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
- << " failed: " << diag.getMessage());
+ LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage();
return WalkResult::advance();
}
@@ -1304,12 +1298,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
if (!getRestrictRoot() && op == root)
return WalkResult::advance();
- DEBUG_MATCHER({
- DBGS_MATCHER() << "matching ";
- op->print(llvm::dbgs(),
- OpPrintingFlags().assumeVerified().skipRegions());
- llvm::dbgs() << " @" << op << "\n";
- });
+ LDBG(1, DEBUG_TYPE_MATCHER)
+ << "matching "
+ << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
+ << " @" << op;
firstMatchArgument.clear();
firstMatchArgument.push_back(op);
@@ -1322,8 +1314,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
- << " failed: " << diag.getMessage());
+ LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage();
continue;
}
@@ -2173,10 +2165,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
transform::TransformResults &results, transform::TransformState &state) {
if (!maybeCurrent.has_value()) {
- DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
+ LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp success";
return DiagnosedSilenceableFailure::success();
}
- DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
+ LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp failure";
return emitSilenceableError() << "operation is not empty";
}
diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index d666390..773eb13 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -90,7 +91,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
//
// Rename private symbols in both ops in order to resolve conflicts that can
// be resolved that way.
- LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ LDBG() << "renaming private symbols to resolve conflicts:";
// TODO: Do we *actually* need to test in both directions?
for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
@@ -102,7 +103,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
if (!symbolOp)
continue;
StringAttr name = symbolOp.getNameAttr();
- LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+ LDBG() << " found @" << name.getValue();
// Check if there is a colliding op in the other module.
auto collidingOp =
@@ -110,7 +111,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
if (!collidingOp)
continue;
- LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+ LDBG() << " collision found for @" << name.getValue();
// Collisions are fine if both opt are functions and can be merged.
if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
@@ -119,13 +120,12 @@ transform::detail::mergeSymbolsInto(Operation *target,
funcOp && collidingFuncOp) {
if (canMergeInto(funcOp, collidingFuncOp) ||
canMergeInto(collidingFuncOp, funcOp)) {
- LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
- "will be merged\n");
+ LDBG() << " but both ops are functions and will be merged";
continue;
}
// If they can't be merged, proceed like any other collision.
- LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+ LDBG() << " and both ops are function definitions";
}
// Collision can be resolved by renaming if one of the ops is private.
@@ -133,7 +133,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
- LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+ LDBG() << ", renaming";
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
if (failed(maybeNewName)) {
@@ -142,8 +142,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "attempted renaming due to collision with this op";
return diag;
}
- LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
- << "\n");
+ LDBG() << " renamed to @" << maybeNewName->getValue();
return InFlightDiagnostic();
};
@@ -161,7 +160,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
return diag;
continue;
}
- LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ LDBG() << ", emitting error";
InFlightDiagnostic diag = symbolOp.emitError()
<< "doubly defined symbol @" << name.getValue();
diag.attachNote(collidingOp->getLoc()) << "previously defined here";
@@ -179,7 +178,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
// Step 2:
//
// Move all ops from `other` into target and merge public symbols.
- LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+ LDBG() << "moving all symbols into target";
{
SmallVector<SymbolOpInterface> opsToMove;
for (Operation &op : other->getRegion(0).front()) {
@@ -193,13 +192,13 @@ transform::detail::mergeSymbolsInto(Operation *target,
targetSymbolTable.lookup(op.getNameAttr()));
// Move op even if we get a collision.
- LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+ LDBG() << " moving @" << op.getName();
op->moveBefore(&target->getRegion(0).front(),
target->getRegion(0).front().end());
// If there is no collision, we are done.
if (!collidingOp) {
- LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+ LDBG() << " without collision";
continue;
}
@@ -217,9 +216,9 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
assert(canMergeInto(funcOp, collidingFuncOp));
- LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
- << collidingFuncOp.getLoc() << ":\n"
- << collidingFuncOp << "\n");
+ LDBG() << " with collision, trying to keep op at "
+ << collidingFuncOp.getLoc() << ":\n"
+ << collidingFuncOp;
// Update symbol table. This works with or without the previous `swap`.
targetSymbolTable.remove(funcOp);
@@ -239,6 +238,6 @@ transform::detail::mergeSymbolsInto(Operation *target,
return target->emitError()
<< "failed to verify target op after merging symbols";
- LLVM_DEBUG(DBGS() << "done merging ops\n");
+ LDBG() << "done merging ops";
return InFlightDiagnostic();
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 14a4fdf..4f4620a 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -312,7 +312,7 @@ LogicalResult transform::TransformState::setParams(Value value,
}
template <typename Mapping, typename Key, typename Mapped>
-void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
+static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
auto it = mapping.find(key);
if (it == mapping.end())
return;
@@ -771,7 +771,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
}
template <typename T>
-DiagnosedSilenceableFailure
+static DiagnosedSilenceableFailure
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
transform::TransformOpInterface transform,
unsigned operandNumber) {
diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
index 41955c8..3ced1a6 100644
--- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
@@ -100,12 +100,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
PatternApplicator applicator(it->second);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
- struct TrivialPatternRewriter : public PatternRewriter {
- public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
- };
- TrivialPatternRewriter rewriter(root->getContext());
+ PatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 35ace1b..9ab484f 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -121,6 +121,80 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
->getLibraryModule();
}
+static transform::TransformOpInterface
+findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) {
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ for (auto namedSequenceOp : block.getOps<transform::NamedSequenceOp>()) {
+ if (namedSequenceOp.getSymName() == entryPoint) {
+ return cast<transform::TransformOpInterface>(
+ namedSequenceOp.getOperation());
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+static transform::TransformOpInterface
+findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
+ transform::TransformOpInterface transform = nullptr;
+ op->walk<WalkOrder::PreOrder>(
+ [&](transform::NamedSequenceOp namedSequenceOp) {
+ if (namedSequenceOp.getSymName() == entryPoint) {
+ transform = cast<transform::TransformOpInterface>(
+ namedSequenceOp.getOperation());
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return transform;
+}
+
+// Looks for the transform entry point, favouring NamedSequenceOps that exist
+// directly in the blocks owned by `op`, without the need for nesting.
+// If no such operation exists, it recursively walks `op` in preorder and
+// returns the first NamedSequenceOp whose name matches the entry point's
+// name.
+//
+// This allows for the following two use cases:
+// 1. op is a module annotated with the transform.with_named_sequence attribute
+// that has an entry point in its block. E.g.,
+//
+// ```mlir
+// module {transform.with_named_sequence} {
+// transform.named_sequence @__transform_main(%arg0 : !transform.any_op) ->
+// () {
+// transform.yield
+// }
+// }
+// ```
+//
+// 2. op is a program which contains a nested module annotated with the
+// transform.with_named_sequence attribute. E.g.,
+//
+// ```mlir
+// module {
+// func.func @foo () {
+// }
+//
+// module {transform.with_named_sequence} {
+// transform.named_sequence @__transform_main(%arg0 : !transform.any_op)
+// -> () {
+// transform.yield
+// }
+// }
+// }
+// ```
+static transform::TransformOpInterface
+findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
+ transform::TransformOpInterface transform =
+ findTransformEntryPointNonRecursive(op, entryPoint);
+ if (!transform)
+ transform = findTransformEntryPointRecursive(op, entryPoint);
+ return transform;
+}
+
transform::TransformOpInterface
transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
StringRef entryPoint) {
@@ -128,16 +202,8 @@ transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
if (module)
l.push_back(module);
for (Operation *op : l) {
- transform::TransformOpInterface transform = nullptr;
- op->walk<WalkOrder::PreOrder>(
- [&](transform::NamedSequenceOp namedSequenceOp) {
- if (namedSequenceOp.getSymName() == entryPoint) {
- transform = cast<transform::TransformOpInterface>(
- namedSequenceOp.getOperation());
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
+ TransformOpInterface transform =
+ findTransformEntryPointInOp(op, entryPoint);
if (transform)
return transform;
}
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index bc85cf4..7b2734d 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -407,7 +407,7 @@ mlir::convertReassociationIndicesToExprs(
}
template <typename AffineExprTy>
-unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
+static unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
unsigned pos = 0;
for (const auto &exprs : exprArrays) {
for (auto expr : exprs) {
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index e6ef028..34385d7 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
if (!ubConstant)
return std::nullopt;
std::optional<int64_t> stepConstant = getConstantIntValue(step);
- if (!stepConstant)
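+  // A zero step would divide by zero below; treat the trip count as unknown.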
+ if (!stepConstant || *stepConstant == 0)
return std::nullopt;
return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a450056..9b2a455 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
return foldToElementsFromElements(*this, results);
}
+LogicalResult
+ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+ ToElementsOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto vecType = cast<VectorType>(adaptor.getSource().getType());
+ Type elType = vecType.getElementType();
+ inferredReturnTypes.append(vecType.getNumElements(), elType);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -2456,8 +2466,12 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
return {};
+ // DenseElementsAttr only supports int/index/float/complex types.
auto destVecType = fromElementsOp.getDest().getType();
auto destEltType = destVecType.getElementType();
+ if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
+ return {};
+
// Constant attributes might have a different type than the return type.
// Convert them before creating the dense elements attribute.
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
@@ -2768,8 +2782,8 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
- if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
- getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+ if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
+ srcType == getElementTypeOrSelf(dstVectorType))
return BroadcastableToResult::Success;
// From now on, only vectors broadcast.
VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
@@ -2841,9 +2855,47 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
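+// For example (illustrative):
+//   %sc = vector.shape_cast %x : vector<4xf32> to vector<1x4xf32>
+//   %b = vector.broadcast %sc : vector<1x4xf32> to vector<2x4xf32>
+// folds to:
+//   %b = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>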
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+  // The trailing dimensions must match, which holds exactly when the
+  // shape_cast only adds or removes leading unit dimensions.
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+ if (!llvm::equal(srcShape.take_back(numTrailingDims),
+ shapecastShape.take_back(numTrailingDims)))
+ return failure();
+
+ assert(all_of(srcShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ all_of(shapecastShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ "ill-formed shape_cast");
+
+ broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getSourceType() == getResultVectorType())
return getSource();
+ if (succeeded(foldBroadcastOfShapeCast(*this)))
+ return getResult();
+
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
@@ -3238,6 +3290,18 @@ LogicalResult InsertOp::verify() {
return success();
}
+// Calculate the linearized position of the contiguous chunk of elements to
+// insert, based on the shape of the value to insert and the positions to insert
+// at.
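+// For example (illustrative): for a vector<2x3x4xf32> destination, positions
+// [1, 2] linearize to 1 * 12 + 2 * 4 = 20.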
+static int64_t calculateInsertPosition(VectorType destTy,
+ ArrayRef<int64_t> positions) {
+ llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+ assert(positions.size() <= completePositions.size() &&
+ "positions size must be less than or equal to destTy rank");
+ copy(positions, completePositions.begin());
+ return linearize(completePositions, computeStrides(destTy.getShape()));
+}
+
namespace {
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3275,6 +3339,132 @@ public:
return success();
}
};
+
+/// Pattern to optimize a chain of insertions.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Only insert values at static positions.
+/// 2. Completely initialize all elements in the resulting vector.
+/// 3. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single vector.from_elements operation.
+///
+/// To keep this pattern simple, and avoid spending too much time on matching
+/// fragmented insert chains, this pattern only considers the last insert op in
+/// the chain.
+///
+/// Example transformation:
+/// %poison = ub.poison : vector<2xi32>
+/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
+class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorType destTy = op.getDestVectorType();
+ if (destTy.isScalable())
+ return failure();
+ // Ensure this is the trailing vector.insert op in a chain of inserts.
+ for (Operation *user : op.getResult().getUsers())
+ if (auto insertOp = dyn_cast<InsertOp>(user))
+ if (insertOp.getDest() == op.getResult())
+ return failure();
+
+ InsertOp currentOp = op;
+ SmallVector<InsertOp> chainInsertOps;
+ while (currentOp) {
+ // Check cond 1: Dynamic position is not supported.
+ if (currentOp.hasDynamicPosition())
+ return failure();
+
+ chainInsertOps.push_back(currentOp);
+ currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
+ // Check cond 3: Intermediate inserts have only one use to avoid an
+ // explosion of vectors.
+ if (currentOp && !currentOp->hasOneUse())
+ return failure();
+ }
+
+ int64_t vectorSize = destTy.getNumElements();
+ int64_t initializedCount = 0;
+ SmallVector<bool> initializedDestIdxs(vectorSize, false);
+ SmallVector<int64_t> pendingInsertPos;
+ SmallVector<int64_t> pendingInsertSize;
+ SmallVector<Value> pendingInsertValues;
+
+ for (auto insertOp : chainInsertOps) {
+ // This pattern can do nothing with poison index.
+ if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+ return failure();
+
+ // Calculate the linearized position for inserting elements.
+ int64_t insertBeginPosition =
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+ // The valueToStore operand may be a vector or a scalar. Need to handle
+ // both cases.
+ int64_t insertSize = 1;
+ if (auto srcVectorType =
+ llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+ insertSize = srcVectorType.getNumElements();
+
+ assert(insertBeginPosition + insertSize <= vectorSize &&
+ "insert would overflow the vector");
+
+ for (auto index : llvm::seq<int64_t>(insertBeginPosition,
+ insertBeginPosition + insertSize)) {
+ if (initializedDestIdxs[index])
+ continue;
+ initializedDestIdxs[index] = true;
+ ++initializedCount;
+ }
+
+      // Defer the creation of ops until we are sure the pattern will
+      // succeed.
+ pendingInsertPos.push_back(insertBeginPosition);
+ pendingInsertSize.push_back(insertSize);
+ pendingInsertValues.push_back(insertOp.getValueToStore());
+
+ if (initializedCount == vectorSize)
+ break;
+ }
+
+ // Check cond 2: all positions must be initialized.
+ if (initializedCount != vectorSize)
+ return failure();
+
+ SmallVector<Value> elements(vectorSize);
+ for (auto [insertBeginPosition, insertSize, valueToStore] :
+ llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
+ pendingInsertValues))) {
+ auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+
+ if (!srcVectorType) {
+ elements[insertBeginPosition] = valueToStore;
+ continue;
+ }
+
+ SmallVector<Type> elementToInsertTypes(insertSize,
+ srcVectorType.getElementType());
+ // Get all elements from the vector in row-major order.
+ auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
+ op.getLoc(), elementToInsertTypes, valueToStore);
+ for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
+ elements[insertBeginPosition + linearIdx] =
+ elementsToInsert.getResult(linearIdx);
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
+ return success();
+ }
+};
+
} // namespace
static Attribute
@@ -3301,13 +3491,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
!insertOp->hasOneUse())
return {};
- // Calculate the linearized position of the continuous chunk of elements to
- // insert.
- llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(insertOp.getStaticPosition(), completePositions.begin());
+ // Calculate the linearized position for inserting elements.
int64_t insertBeginPosition =
- linearize(completePositions, computeStrides(destTy.getShape()));
-
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();
@@ -3343,7 +3529,8 @@ static Value foldInsertUseChain(InsertOp insertOp) {
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+ InsertChainFullyInitialized>(context);
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
@@ -5599,7 +5786,7 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != baseType.getRank())
+ if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
@@ -5671,11 +5858,11 @@ public:
if (!isa<MemRefType>(op.getBase().getType()))
return rewriter.notifyMatchFailure(op, "base must be of memref type");
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
- op.getIndices(), op.getMask(),
+ op.getOffsets(), op.getMask(),
op.getPassThru());
return success();
}
@@ -5699,7 +5886,7 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
+ if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
@@ -5734,11 +5921,11 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
- op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2d5cc07..fe066dc 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
vector::populateVectorGatherLoweringPatterns(patterns);
}
+void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 6619619..546099c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
- gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
gatherOp.getPassThru());
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9e287fc..acbf2b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
+ LowerVectorFromElements.cpp
LowerVectorGather.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
new file mode 100644
index 0000000..c22fd54
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
@@ -0,0 +1,65 @@
+//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.from_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-from-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = ub.poison : vector<2x3xf32>
+/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange allElements = op.getElements();
+
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ size_t subTyNumElements = subTy.getNumElements();
+ assert((index + 1) * subTyNumElements <= allElements.size() &&
+ "out of bounds");
+ ValueRange subElements =
+ allElements.slice(index * subTyNumElements, subTyNumElements);
+ return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorFromElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55..9830189 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already 1-D");
-
- // Unrolling doesn't take vscale into account. Pattern is disabled for
- // vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
- return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
- Location loc = op.getLoc();
- Value indexVec = op.getIndexVec();
+ Value indexVec = op.getIndices();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
- Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
- rewriter.getZeroAttr(resultTy));
-
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
+ auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
Value indexSubVec =
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
@@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
- Value subGather = vector::GatherOp::create(
- rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
- maskSubVec, passThruSubVec);
- result =
- vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
- }
+ return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
+ op.getOffsets(), indexSubVec, maskSubVec,
+ passThruSubVec);
+ };
- rewriter.replaceOp(op, result);
- return success();
+ return unrollVectorOp(op, rewriter, unrollGatherFn);
}
};
@@ -158,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
- VectorType vType = op.getIndexVec().getType();
+ VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
Value newIdxs =
- arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+ arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = vector::GatherOp::create(
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
- op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+ op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);
return success();
@@ -212,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
- op.getIndexVec());
- auto baseOffsets = llvm::to_vector(op.getIndices());
+ op.getIndices());
+ auto baseOffsets = llvm::to_vector(op.getOffsets());
Value lastBaseOffset = baseOffsets.back();
Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f0..5617b06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ public:
// Replace the `vector.mask` operation.
rewriter.replaceOpWithNewOp<GatherOp>(
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
- gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
passthru);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bb0f339..c84eb2c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
if (!writeOp)
@@ -706,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern {
}
};
+/// Sink out step op feeding into a warp op yield.
+/// The vector step op is treated similarly to arith.constant, except that
+/// its result represents the sequence [0, vec_size).
+/// Due to the vec_size == warp_size limitation, we can simply wrap the lane
+/// id into a vector (i.e., broadcast it).
+/// Supporting vec_size != warp_size may involve preserving the step
+/// result and using additional arith ops (the exact details are TBD).
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+/// ...
+/// %cst = vector.step : vector<32xindex>
+///   gpu.yield %cst : vector<32xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+/// ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+/// ```
+struct WarpOpStep final : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+ if (!yieldOperand)
+ return failure();
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
+ auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+ VectorType resTy = stepOp.getResult().getType();
+ if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
+ resTy.getNumElements(), warpOp.getWarpSize()));
+ VectorType newVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ rewriter.setInsertionPointAfter(warpOp);
+ Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+ newVecTy, warpOp.getLaneid());
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
+ return success();
+ }
+};
+
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -846,8 +891,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
newYieldValues.reserve(warpOp->getNumResults());
DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
DenseMap<OpResult, int64_t> dedupResultPositionMap;
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +945,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Value valForwarded;
unsigned resultIndex;
for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1751,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto warpOpYield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
// Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = warpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
@@ -1826,7 +1868,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands);
+ forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
@@ -2019,7 +2062,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+ WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 491b448..7dde631 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
}
};
+/// This pattern linearizes `vector.from_elements` operations by converting
+/// the result type to a 1-D vector while preserving all element values.
+/// The transformation creates a linearized `vector.from_elements` followed by
+/// a `vector.shape_cast` to restore the original multidimensional shape.
+///
+/// Example:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
+/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+///
+struct LinearizeVectorFromElements final
+ : public OpConversionPattern<vector::FromElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorFromElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ LogicalResult
+ matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType dstTy =
+ getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
+ assert(dstTy && "vector type destination expected.");
+
+ OperandRange elements = fromElementsOp.getElements();
+ assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
+ "expected same number of elements");
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
+ elements);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore>(typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c707f38..369857f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (isa<ViewLikeOpInterface>(user)) {
- users.append(user->getUsers().begin(), user->getUsers().end());
+ if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
@@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (isa<ViewLikeOpInterface>(user)) {
- users.append(user->getUsers().begin(), user->getUsers().end());
+ if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2269a40..dbb5eb3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -600,7 +600,7 @@ struct BubbleDownVectorBitCastForExtract
// Get the first element of the mixed position as integer.
auto mixedPos = extractOp.getMixedPosition();
- if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
+ if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
return failure();
uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
@@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
- auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+ auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
if (!resType)
return failure();
if (resType.getRank() != 2)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abec..e8ecb0c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// decomposed shape from each of the index, mask, and pass-through
// vectors.
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+ loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
auto slicedGather = vector::GatherOp::create(
- rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+ rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
indexSubVec, maskSubVec, passThruSubVec);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 10ed2bc..841e138 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
// Attempt to unroll until targetRank or the first scalable dimension (which
// cannot be unrolled).
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
- auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
- auto it = llvm::find(scalableDimsToUnroll, true);
- auto firstScalableDim = it - scalableDimsToUnroll.begin();
+ auto inputScalableVecDimsToUnroll =
+ vType.getScalableDims().drop_back(targetRank);
+ auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+ auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
if (firstScalableDim == 0)
return {};
// All scalable dimensions should be removed now.
- scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
- assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+ inputScalableVecDimsToUnroll =
+ inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+ assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
"unexpected leading scalable dimension");
// Create an unroll iterator for leading dimensions.
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
bool useInBoundsInsteadOfMasking,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> inputScalableVecDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
- auto vectorType =
- VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+ inputScalableVecDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
- auto maskType =
- VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+ auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+ inputScalableVecDims);
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
@@ -385,9 +387,34 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
staticSize <= inputSize;
})) {
LDBG() << "Input vector sizes must be greater than or equal to iteration "
- "space "
- "static sizes";
+ "space static sizes";
return failure();
}
return success();
}
+
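+/// Unrolls the leading dimension of `op`'s single vector result: `unrollFn`
+/// materializes each rank-reduced slice, and the slices are reassembled into
+/// the full result with vector.insert over a ub.poison initial value.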
+LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ vector::UnrollVectorOpFn unrollFn) {
+ assert(op->getNumResults() == 1 && "expected single result");
+ assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
+ VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op->getLoc();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ Value subVector = unrollFn(rewriter, loc, subTy, i);
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
diff --git a/mlir/lib/Dialect/WasmSSA/CMakeLists.txt b/mlir/lib/Dialect/WasmSSA/CMakeLists.txt
new file mode 100644
index 0000000..f33061b2
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt b/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt
new file mode 100644
index 0000000..9fc2d7b
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/IR/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_dialect_library(MLIRWasmSSADialect
+ WasmSSAOps.cpp
+ WasmSSADialect.cpp
+ WasmSSAInterfaces.cpp
+ WasmSSATypes.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/WasmSSA
+
+ DEPENDS
+ MLIRWasmSSAOpsIncGen
+ MLIRWasmSSAInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRCastInterfaces
+ MLIRDataLayoutInterfaces
+ MLIRDialect
+ MLIRInferTypeOpInterface
+ MLIRIR
+ MLIRSupport
+
+ PRIVATE
+ MLIRFunctionInterfaces
+ )
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp
new file mode 100644
index 0000000..98c3555
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSADialect.cpp
@@ -0,0 +1,38 @@
+//===- WasmSSADialect.cpp - MLIR WasmSSA dialect implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd types definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.cpp.inc"
+
+void wasmssa::WasmSSADialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
+ >();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.cpp.inc"
+ >();
+}
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
new file mode 100644
index 0000000..61cdf6f
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp
@@ -0,0 +1,69 @@
+//===- WasmSSAInterfaces.cpp - WasmSSA Interfaces -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines op interfaces for the WasmSSA dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/LogicalResult.h"
+
+namespace mlir::wasmssa {
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.cpp.inc"
+
+namespace detail {
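+/// Verifies that the op's exit level resolves to an enclosing op
+/// implementing LabelLevelOpInterface (see getTargetOpFromBlock below).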
+LogicalResult verifyLabelBranchingOpInterface(Operation *op) {
+ auto branchInterface = dyn_cast<LabelBranchingOpInterface>(op);
+ llvm::FailureOr<LabelLevelOpInterface> res =
+ LabelBranchingOpInterface::getTargetOpFromBlock(
+ op->getBlock(), branchInterface.getExitLevel());
+ return res;
+}
+
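+/// Verifies that the op's initializer region contains only ReturnOp and ops
+/// carrying ConstantExprOpTrait.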
+LogicalResult verifyConstantExpressionInterface(Operation *op) {
+ Region &initializerRegion = op->getRegion(0);
+ WalkResult resultState =
+ initializerRegion.walk([&](Operation *currentOp) -> WalkResult {
+ if (isa<ReturnOp>(currentOp) ||
+ currentOp->hasTrait<ConstantExprOpTrait>())
+ return WalkResult::advance();
+ op->emitError("expected a constant initializer for this operator, got ")
+ << currentOp;
+ return WalkResult::interrupt();
+ });
+ return success(!resultState.wasInterrupted());
+}
+
+LogicalResult verifyLabelLevelInterface(Operation *op) {
+ Block *target = cast<LabelLevelOpInterface>(op).getLabelTarget();
+ Region *targetRegion = target->getParent();
+ if (targetRegion != op->getParentRegion() &&
+ targetRegion->getParentOp() != op)
+ return op->emitError("target should be a block defined at the same level "
+ "as the operation or in its region.");
+ return success();
+}
+} // namespace detail
+
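+/// Resolves the target of a branch with the given break level: walks
+/// `breakLevel` + 1 enclosing ops starting from `block`, failing if any of
+/// them does not implement LabelLevelOpInterface.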
+llvm::FailureOr<LabelLevelOpInterface>
+LabelBranchingOpInterface::getTargetOpFromBlock(::mlir::Block *block,
+ uint32_t breakLevel) {
+ LabelLevelOpInterface res{};
+ for (size_t curLevel{0}; curLevel <= breakLevel; curLevel++) {
+ res = dyn_cast_or_null<LabelLevelOpInterface>(block->getParentOp());
+ if (!res)
+ return failure();
+ block = res->getBlock();
+ }
+ return res;
+}
+} // namespace mlir::wasmssa
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
new file mode 100644
index 0000000..89b62a2
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -0,0 +1,494 @@
+//===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/Casting.h"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+using namespace mlir;
+namespace {
+ParseResult parseElseRegion(OpAsmParser &opParser, Region &elseRegion) {
+ std::string keyword;
+ std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
+ if (keyword == "else")
+ return opParser.parseRegion(elseRegion);
+ return ParseResult::success();
+}
+
+void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
+ Region &elseRegion) {
+ if (elseRegion.empty())
+ return;
+ opPrinter.printKeywordOrString("else ");
+ opPrinter.printRegion(elseRegion);
+}
+
+ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
+ std::string keyword;
+ auto initLocation = opParser.getCurrentLocation();
+ std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
+ if (keyword == "nested" || keyword.empty()) {
+ visibility = StringAttr::get(opParser.getContext(), "nested");
+ return ParseResult::success();
+ }
+
+ if (keyword == "public" || keyword == "private") {
+ visibility = StringAttr::get(opParser.getContext(), keyword);
+ return ParseResult::success();
+ }
+ opParser.emitError(initLocation, "expecting symbol visibility");
+ return ParseResult::failure();
+}
+
+void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
+ Attribute visibility) {
+ opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/LogicalResult.h"
+
+using namespace wasmssa;
+
+namespace {
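+/// Infers the result type of local.get/local.tee style ops: the result is
+/// the element type referenced by the LocalRefType operand.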
+inline LogicalResult
+inferTeeGetResType(ValueRange operands,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (operands.empty())
+ return failure();
+ auto opType = dyn_cast<LocalRefType>(operands.front().getType());
+ if (!opType)
+ return failure();
+ inferredReturnTypes.push_back(opType.getElementType());
+ return success();
+}
+
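+/// Parses an import clause of the form:
+///   "<importName>" from "<moduleName>" as @<symbol>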
+ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) {
+ std::string importName;
+ auto *ctx = parser.getContext();
+ ParseResult res = parser.parseString(&importName);
+ result.addAttribute("importName", StringAttr::get(ctx, importName));
+
+ std::string fromStr;
+ res = parser.parseKeywordOrString(&fromStr);
+ if (failed(res) || fromStr != "from")
+ return failure();
+
+ std::string moduleName;
+ res = parser.parseString(&moduleName);
+ if (failed(res))
+ return failure();
+ result.addAttribute("moduleName", StringAttr::get(ctx, moduleName));
+
+ std::string asStr;
+ res = parser.parseKeywordOrString(&asStr);
+ if (failed(res) || asStr != "as")
+ return failure();
+
+ StringAttr symbolName;
+ res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+ result.attributes);
+ return res;
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BlockOp
+//===----------------------------------------------------------------------===//
+
+Block *BlockOp::getLabelTarget() { return getTarget(); }
+
+//===----------------------------------------------------------------------===//
+// BlockReturnOp
+//===----------------------------------------------------------------------===//
+
+std::size_t BlockReturnOp::getExitLevel() { return 0; }
+
+Block *BlockReturnOp::getTarget() {
+ return cast<LabelBranchingOpInterface>(getOperation())
+ .getTargetOp()
+ .getOperation()
+ ->getSuccessor(0);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtendLowBitsSOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtendLowBitsSOp::verify() {
+ auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
+ if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
+ return emitError("extend op can only take 8, 16 or 32 bits. Got ")
+ << bitsToTake;
+
+ if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
+ return emitError("trying to extend the ")
+ << bitsToTake << " low bits from a " << getInput().getType()
+ << " value is illegal";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+Block *FuncOp::addEntryBlock() {
+ if (!getBody().empty()) {
+ emitError("adding entry block to a FuncOp which already has one");
+ return &getBody().front();
+ }
+ Block &block = getBody().emplaceBlock();
+ for (auto argType : getFunctionType().getInputs())
+ block.addArgument(LocalRefType::get(argType), getLoc());
+ return &block;
+}
+
+void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, FunctionType funcType) {
+ FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
+}
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
+ ArrayRef<Type> results,
+ function_interface_impl::VariadicFlag,
+ std::string &) {
+ SmallVector<Type> argTypesWithoutLocal{};
+ argTypesWithoutLocal.reserve(argTypes.size());
+ llvm::for_each(argTypes, [&parser, &argTypesWithoutLocal](Type argType) {
+ auto refType = dyn_cast<LocalRefType>(argType);
+ auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+ if (!refType) {
+ mlir::emitError(loc, "invalid type for wasm.func argument. Expecting "
+ "!wasm<local T>, got ")
+ << argType;
+ return;
+ }
+ argTypesWithoutLocal.push_back(refType.getElementType());
+ });
+
+ return builder.getFunctionType(argTypesWithoutLocal, results);
+ };
+
+ return function_interface_impl::parseFunctionOp(
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+LogicalResult FuncOp::verifyBody() {
+ if (getBody().empty())
+ return success();
+ Block &entry = getBody().front();
+ if (entry.getNumArguments() != getFunctionType().getNumInputs())
+ return emitError("entry block should have the same number of arguments "
+ "as the function type. Function type has ")
+ << getFunctionType().getNumInputs() << ", entry block has "
+ << entry.getNumArguments();
+
+ for (auto [argNo, funcSignatureType, blockType] : llvm::enumerate(
+ getFunctionType().getInputs(), entry.getArgumentTypes())) {
+ auto blockLocalRefType = dyn_cast<LocalRefType>(blockType);
+ if (!blockLocalRefType)
+ return emitError("entry block argument type should be LocalRefType, got ")
+ << blockType << " for block argument " << argNo;
+ if (blockLocalRefType.getElementType() != funcSignatureType)
+ return emitError("func argument type #")
+ << argNo << "(" << funcSignatureType
+ << ") doesn't match entry block referenced type ("
+ << blockLocalRefType.getElementType() << ")";
+ }
+ return success();
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName());
+}
+
+//===----------------------------------------------------------------------===//
+// FuncImportOp
+//===----------------------------------------------------------------------===//
+
+void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, StringRef moduleName,
+ StringRef importName, FunctionType type) {
+ FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ type, {}, {}, odsBuilder.getStringAttr("nested"));
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+
+void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, Type type, bool isMutable) {
+ GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
+ odsBuilder.getStringAttr("nested"));
+}
+
+// Custom formats
+ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
+ StringAttr symbolName;
+ Type globalType;
+ auto *ctx = parser.getContext();
+ ParseResult res = parser.parseSymbolName(
+ symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
+
+ res = parser.parseType(globalType);
+ result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
+ std::string mutableString;
+ res = parser.parseOptionalKeywordOrString(&mutableString);
+ if (res.succeeded() && mutableString == "mutable")
+ result.addAttribute("isMutable", UnitAttr::get(ctx));
+ std::string visibilityString;
+ res = parser.parseOptionalKeywordOrString(&visibilityString);
+ if (res.succeeded())
+ result.addAttribute("sym_visibility",
+ StringAttr::get(ctx, visibilityString));
+ res = parser.parseColon();
+ Region *globalInitRegion = result.addRegion();
+ res = parser.parseRegion(*globalInitRegion);
+ return res;
+}
+
+void GlobalOp::print(OpAsmPrinter &printer) {
+ printer << " @" << getSymName().str() << " " << getType();
+ if (getIsMutable())
+ printer << " mutable";
+ if (auto vis = getSymVisibility())
+ printer << " " << *vis;
+ printer << " :";
+ Region &body = getRegion();
+ if (!body.empty()) {
+ printer << ' ';
+ printer.printRegion(body, /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalGetOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ // If the parent requires a constant context, verify that global.get is a
+ // constant as defined per the wasm standard.
+ if (!this->getOperation()
+ ->getParentWithTrait<ConstantExpressionInitializerOpTrait>())
+ return success();
+ Operation *symTabOp = SymbolTable::getNearestSymbolTable(*this);
+ StringRef referencedSymbol = getGlobal();
+ Operation *definitionOp = symbolTable.lookupSymbolIn(
+ symTabOp, StringAttr::get(this->getContext(), referencedSymbol));
+ if (!definitionOp)
+ return emitError() << "symbol @" << referencedSymbol << " is undefined";
+ auto definitionImport = dyn_cast<GlobalImportOp>(definitionOp);
+ if (!definitionImport || definitionImport.getIsMutable()) {
+ return emitError("global.get op is considered constant only if it refers "
+ "to an import.global symbol marked non-mutable");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalImportOp
+//===----------------------------------------------------------------------===//
+
+void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, StringRef moduleName,
+ StringRef importName, Type type, bool isMutable) {
+ GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ type, isMutable, odsBuilder.getStringAttr("nested"));
+}
+
+ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto *ctx = parser.getContext();
+ ParseResult res = parseImportOp(parser, result);
+ if (res.failed())
+ return failure();
+ std::string mutableOrSymVisString;
+ res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
+ if (res.succeeded() && mutableOrSymVisString == "mutable") {
+ result.addAttribute("isMutable", UnitAttr::get(ctx));
+ res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
+ }
+
+ if (res.succeeded())
+ result.addAttribute("sym_visibility",
+ StringAttr::get(ctx, mutableOrSymVisString));
+ res = parser.parseColon();
+
+ Type importedType;
+ res = parser.parseType(importedType);
+ if (res.succeeded())
+ result.addAttribute(getTypeAttrName(result.name),
+ TypeAttr::get(importedType));
+ return res;
+}
+
+void GlobalImportOp::print(OpAsmPrinter &printer) {
+ printer << " \"" << getImportName() << "\" from \"" << getModuleName()
+ << "\" as @" << getSymName();
+ if (getIsMutable())
+ printer << " mutable";
+ if (auto vis = getSymVisibility())
+ printer << " " << *vis;
+ printer << " : " << getType();
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+Block *IfOp::getLabelTarget() { return getTarget(); }
+
+//===----------------------------------------------------------------------===//
+// LocalOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LocalOp::inferReturnTypes(
+ MLIRContext *context, ::std::optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
+ RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
+ LocalOp::GenericAdaptor<ValueRange> adaptor{operands, attributes, properties,
+ regions};
+ auto type = adaptor.getTypeAttr();
+ if (!type)
+ return failure();
+ auto resType = LocalRefType::get(type.getContext(), type.getValue());
+ inferredReturnTypes.push_back(resType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LocalGetOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LocalGetOp::inferReturnTypes(
+ MLIRContext *context, ::std::optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
+ RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
+ return inferTeeGetResType(operands, inferredReturnTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// LocalSetOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LocalSetOp::verify() {
+ if (getLocalVar().getType().getElementType() != getValue().getType())
+ return emitError("input type and result type of local.set do not match");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LocalTeeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LocalTeeOp::inferReturnTypes(
+ MLIRContext *context, ::std::optional<Location> location,
+ ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
+ RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
+ return inferTeeGetResType(operands, inferredReturnTypes);
+}
+
+LogicalResult LocalTeeOp::verify() {
+ if (getLocalVar().getType().getElementType() != getValue().getType() ||
+ getValue().getType() != getResult().getType())
+ return emitError("input type and output type of local.tee do not match");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+Block *LoopOp::getLabelTarget() { return &getBody().front(); }
+
+//===----------------------------------------------------------------------===//
+// MemOp
+//===----------------------------------------------------------------------===//
+
+void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, LimitType limit) {
+ MemOp::build(odsBuilder, odsState, symbol, limit,
+ odsBuilder.getStringAttr("nested"));
+}
+
+//===----------------------------------------------------------------------===//
+// MemImportOp
+//===----------------------------------------------------------------------===//
+
+void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, StringRef moduleName,
+ StringRef importName, LimitType limits) {
+ MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ limits, odsBuilder.getStringAttr("nested"));
+}
+
+//===----------------------------------------------------------------------===//
+// ReinterpretOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReinterpretOp::verify() {
+ auto inT = getInput().getType();
+ auto resT = getResult().getType();
+ if (inT == resT)
+ return emitError("reinterpret input and output types should be distinct");
+ if (inT.getIntOrFloatBitWidth() != resT.getIntOrFloatBitWidth())
+ return emitError() << "input type (" << inT << ") and output type (" << resT
+ << ") have incompatible bit widths";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
+
+//===----------------------------------------------------------------------===//
+// TableOp
+//===----------------------------------------------------------------------===//
+
+void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, TableType type) {
+ TableOp::build(odsBuilder, odsState, symbol, type,
+ odsBuilder.getStringAttr("nested"));
+}
+
+//===----------------------------------------------------------------------===//
+// TableImportOp
+//===----------------------------------------------------------------------===//
+
+void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef symbol, StringRef moduleName,
+ StringRef importName, TableType type) {
+ TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
+ type, odsBuilder.getStringAttr("nested"));
+}
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
new file mode 100644
index 0000000..bee8c81
--- /dev/null
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSATypes.cpp
@@ -0,0 +1,18 @@
+//===- WasmSSATypes.cpp - WasmSSA dialect types ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/LogicalResult.h"
+
+#include <optional>
+
+namespace mlir::wasmssa {
+#include "mlir/Dialect/WasmSSA/IR/WasmSSATypeConstraints.cpp.inc"
+} // namespace mlir::wasmssa
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97c..7869a28 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,13 +7,18 @@ add_mlir_dialect_library(MLIRXeGPUDialect
DEPENDS
MLIRXeGPUIncGen
+ MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRIndexDialect
+ MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRIR
MLIRViewLikeInterface
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3c0ca114..7f3be7f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,12 +6,16 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
using std::optional;
@@ -33,10 +37,61 @@ void XeGPUDialect::initialize() {
>();
}
+/// Generates instructions to compute offsets for a subgroup identified by
+/// its multidimensional indices (sgId), using the specified subgroup layout
+/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
+/// dimensions (sizePerWg).
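+/// For example (illustrative values): with sgLayout = [2, 2],
+/// sizePerSg = [16, 16] and sizePerWg = [64, 32], distUnit is [32, 32], so
+/// the subgroup with sgId = (1, 1) is assigned the offsets (16, 16) and
+/// (48, 16), i.e. the data is distributed round-robin along dim 0.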
+static SmallVector<SmallVector<Value>>
+genOffsetsComputingInsts(OpBuilder &builder, Location loc,
+ SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
+ ArrayRef<int64_t> sizePerSg,
+ ArrayRef<int64_t> sizePerWg) {
+
+ SmallVector<SmallVector<Value>> offsets;
+
+ // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
+ SmallVector<Value> localOffsets = llvm::map_to_vector(
+ llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::MulOp>(
+ loc, std::get<0>(t),
+ builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+ });
+
+ // distUnit[i] is the minimum value between sizePerWg[i] and
+ // sgLayout[i] * sizePerSg[i]
+ SmallVector<int64_t> distUnit = llvm::map_to_vector(
+ llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+ for (SmallVector<int64_t> unitOffs :
+ StaticTileOffsetRange(sizePerWg, distUnit)) {
+ SmallVector<Value> base =
+ llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
+ return arith::ConstantIndexOp::create(builder, loc, d);
+ });
+
+ SmallVector<Value> adds = llvm::map_to_vector(
+ llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
+ std::get<1>(t));
+ });
+
+ SmallVector<Value> mods = llvm::map_to_vector(
+ llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::RemUOp>(
+ loc, std::get<0>(t),
+ arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
+ });
+
+ offsets.push_back(mods);
+ }
+ return offsets;
+}
+
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
- xegpu::LayoutAttr attr) {
+ xegpu::DistributeLayoutAttr attr) {
assert(attr && "Layout attribute is missing.");
// Checks whether the given shape can be evenly distributed using the
@@ -49,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
// smaller than `layout[i] * data[i]`, allowing multiple compute units to
// share the data.
auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
- DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
+ SmallVector<int64_t> layout,
+ SmallVector<int64_t> data,
bool rr = true) -> optional<SmallVector<int64_t>> {
llvm::SmallVector<int64_t> newShape(shape);
- if (layout) {
- auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
- if (vec.size() != shape.size())
+ if (layout.size()) {
+ if (layout.size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(shape, vec);
+ auto ratio = computeShapeRatio(shape, layout);
if (!ratio.has_value())
return std::nullopt;
newShape = ratio.value();
}
- if (data) {
- auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
- if (vec.size() != shape.size())
+ if (data.size()) {
+ if (data.size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(newShape, vec);
+ auto ratio = computeShapeRatio(newShape, data);
if (!ratio.has_value() && rr)
- ratio = computeShapeRatio(vec, newShape);
+ ratio = computeShapeRatio(data, newShape);
if (!ratio.has_value())
return std::nullopt;
// if data is not null, we always return it for next phase.
- newShape = vec;
+ newShape = data;
}
return newShape;
};
// check the sgLayout and sgData
auto maybeSgShape =
- tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
+ tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
if (!maybeSgShape)
return false;
auto sgShape = maybeSgShape.value();
// check InstData, it neither have layout nor need round-robin
auto maybeInstShape =
- tryDistribute(sgShape, nullptr, attr.getInstData(), false);
+ tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false);
if (!maybeInstShape)
return false;
auto instShape = maybeInstShape.value();
// check LaneLayout and LaneData
- auto maybeLaneShape =
- tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
+ auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
+ attr.getLaneDataAsInt(), false);
return maybeLaneShape.has_value();
}
@@ -211,6 +265,150 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return success();
}
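+/// Delinearizes a subgroup's linear id into one index per dimension of the
+/// subgroup layout (row-major for the default order), e.g. for a subgroup
+/// layout of [4, 8] a linear id of 10 maps to [1, 2].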
+FailureOr<SmallVector<Value>>
+LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ // delinearizeSubgroupId is only available for
+ // workgroup-level layout attribute
+ if (!isForWorkgroup())
+ return failure();
+
+ // TODO: handle order attribute
+ auto hasDefaultOrder = [&]() {
+ DenseI32ArrayAttr order = getOrder();
+ return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
+ llvm::reverse(order.asArrayRef())));
+ };
+ if (!hasDefaultOrder())
+ return mlir::emitError(loc, "order attribute is currently not supported.");
+
+ auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
+
+ return affine::delinearizeIndex(builder, loc, linearId, dims);
+}
+
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ if (!isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
+ SmallVector<int64_t> sgShape = getSgDataAsInt();
+ if (sgShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+ }
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+ SmallVector<Value> sgIds = *maybeIds;
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
+ if (!parent || !dims)
+ return emitError() << "expected parent layout and dims attribute";
+
+ int64_t rank = parent.getRank();
+
+ // check every element in dims is unique and smaller than rank
+ llvm::SmallDenseSet<int64_t> seen;
+ for (int64_t dim : dims.asArrayRef()) {
+ if (dim < 0 || dim >= rank)
+ return emitError() << "invalid dim (" << dim << ") in slice attribute.";
+ if (!seen.insert(dim).second)
+ return emitError() << "repeated dim (" << dim << ") in slice attribute.";
+ }
+ return success();
+}
+
+SliceAttr SliceAttr::flatten() const {
+ xegpu::DistributeLayoutAttr parent = getParent();
+ SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
+
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
+ parent = sliceAttr.getParent();
+ slicedDims.push_back(sliceAttr.getDims());
+ }
+
+ auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
+ SmallVector<int64_t> indices =
+ llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
+
+ // get remaining dims (flattened) by applying slice ops with all slicedDims
+ SmallVector<int64_t> remainingDims(indices);
+ for (auto dim : llvm::reverse(slicedDims))
+ remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
+ dim.asArrayRef());
+
+ // get flattened sliced dims by applying slice ops with the remaining dims
+ SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
+ llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
+
+ return xegpu::SliceAttr::get(
+ getContext(), layoutAttr,
+ DenseI64ArrayAttr::get(getContext(), flattendDims));
+}
+
+FailureOr<SmallVector<Value>>
+SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.delinearizeSubgroupId(builder, loc, linearId);
+}
+
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+ if (!isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
+ SmallVector<int64_t> sgShape = getSgDataAsInt();
+ if (sgShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+ }
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+
+ // The effective sgIds for offset computation correspond
+ // to the dims that are not sliced.
+ ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
+ SmallVector<Value> sgIds =
+ XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
@@ -230,7 +428,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
-mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+mlir::Type TensorDescType::parse(AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
@@ -280,7 +478,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
layout.value_or(mlir::Attribute()));
}
-void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+void TensorDescType::print(AsmPrinter &printer) const {
printer << "<";
auto shape = getShape();
@@ -325,10 +523,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, layout);
}
-LogicalResult TensorDescType::verify(
- llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
- llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
- mlir::Attribute encoding, mlir::Attribute layout) {
+LogicalResult
+TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
if (rank == 0)
@@ -394,6 +592,119 @@ LogicalResult TensorDescType::verify(
return success();
}
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
+mlir::Type MemDescType::parse(AsmParser &parser) {
+ llvm::SmallVector<int64_t> shape;
+ mlir::Type elementType;
+ mlir::FailureOr<MemLayoutAttr> layout;
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ auto shapeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
+ parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+ return {};
+ }
+
+ auto elemTypeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseType(elementType))) {
+ parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+ return {};
+ }
+
+ // parse optional attributes
+ if (mlir::succeeded(parser.parseOptionalComma())) {
+ MemLayoutAttr attr;
+ ParseResult res = parser.parseAttribute(attr);
+ if (mlir::failed(res))
+ return {};
+ layout = attr;
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ MLIRContext *ctxt = parser.getContext();
+ return MemDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+ elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+ printer << "<";
+
+ printer.printDimensionList(getShape());
+ printer << 'x';
+ printer << getElementType();
+
+ if (auto layout = getMemLayout())
+ printer << ", " << layout;
+
+ printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemLayoutAttr
+//===----------------------------------------------------------------------===//
+
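+/// Parses a comma-separated list of `name = attribute` entries enclosed in
+/// angle brackets, rejecting duplicate keys.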
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+
+ auto context = parser.getContext();
+ llvm::SMLoc loc = parser.getCurrentLocation();
+
+ llvm::SmallDenseSet<StringRef> seenKeys;
+ SmallVector<NamedAttribute> attributes;
+
+ auto parseElt = [&]() -> ParseResult {
+ StringRef nameId;
+ if (failed(parser.parseKeyword(&nameId)))
+ return parser.emitError(loc, "expected valid attribute name");
+
+ if (!seenKeys.insert(nameId).second)
+ return parser.emitError(loc, "duplicate key '")
+ << nameId << "' in mem layout attribute";
+
+ if (failed(parser.parseEqual()))
+ return failure();
+
+ Attribute attr;
+ if (failed(parser.parseAttribute(attr)))
+ return failure();
+ attributes.emplace_back(nameId, attr);
+ return success();
+ };
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ if (failed(parser.parseCommaSeparatedList(parseElt)))
+ return {};
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ return parser.getChecked<MemLayoutAttr>(
+ loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
+ for (size_t i = 0; i < attrs.size(); i++) {
+ printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
+ if (i < attrs.size() - 1)
+ printer << ", ";
+ }
+ printer << ">";
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3..aca6654 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
namespace mlir {
namespace xegpu {
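+/// Returns true if the memref is allocated in shared local memory, i.e. its
+/// memory space is integer address space 3, xegpu::MemorySpace::SLM,
+/// xevm::AddrSpace::SHARED, or the GPU dialect's workgroup address space.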
+bool isSharedMemory(const MemRefType &memrefTy) {
+ Attribute attr = memrefTy.getMemorySpace();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 3;
+ if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+ return memrefSpace.getValue() == MemorySpace::SLM;
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
@@ -45,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) {
return shape;
}
-static int64_t getRankOf(Value val) {
- auto type = val.getType();
- if (auto ty = llvm::dyn_cast<ShapedType>(type))
- return ty.getRank();
- return 0;
-}
-
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
if (!attr)
return true;
@@ -76,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
if (!tdescTy.isScattered())
return emitError() << "Expects a scattered TensorDesc.";
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (dyn_cast<VectorType>(maskTy))
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
- auto chunkSize = tdescTy.getChunkSizeAsInt();
if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
@@ -111,25 +122,49 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
}
static LogicalResult
-isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
- int64_t chunkSize,
+isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
+ VectorType valueTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto maskVecTy = dyn_cast<VectorType>(maskTy);
+ auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (maskVecTy && offsetsVecTy)
+ return emitError() << "Expecting a vector type result.";
+ if (maskVecTy || offsetsVecTy)
+ return emitError() << "Expecting scalar mask and offsets.";
+ return success();
+ }
+ auto valueSize = valueTy.getNumElements();
+ // SIMT mode with scalar mask and offsets.
+ if (!maskVecTy && !offsetsVecTy) {
+ if (valueSize != chunkSize)
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
- // a valid shape for SIMT case
- if (valueTy.getRank() == 1) {
- if (valueTy.getNumElements() != chunkSize)
- return emitError() << "value elements must match chunk size " << chunkSize
- << " for SIMT code.";
- return success();
+ if (!maskVecTy)
+ return emitError() << "Expecting a vector type mask.";
+ int64_t maskSize = maskVecTy.getNumElements();
+
+ if (chunkSize > 1) {
+ if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ } else {
+ if (valueSize != maskSize)
+ return emitError()
+ << "Mask should match value except the chunk size dim.";
}
-
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (maskSize == 1)
+ return success();
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
@@ -156,41 +191,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<MemRefType> source,
+ Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
+ Type srcTy = source.getType();
+ assert((isa<IntegerType, MemRefType>(srcTy)) &&
+ "Source has to be either int or memref.");
- llvm::SmallVector<int64_t> staticShape;
- llvm::SmallVector<int64_t> staticStrides;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
- dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
- auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
- build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
- dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
- staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<IntegerType> source,
- llvm::ArrayRef<OpFoldResult> shape,
- llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
-
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
- llvm::SmallVector<Value> dynamicShape;
- llvm::SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +210,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+ auto memrefShape = memrefTy.getShape();
+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // If the shape and strides match those of the memref, omit the attributes
+    // to keep the printed IR clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ staticShapeAttr = DenseI64ArrayAttr();
+ staticStridesAttr = DenseI64ArrayAttr();
+ }
+ }
+
build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
staticStridesAttr);
@@ -265,8 +289,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- auto rank = (int64_t)getMixedOffsets().size();
- bool invalidRank = false;
+ size_t rank = getMixedSizes().size();
+ bool invalidRank = rank != getMixedStrides().size();
bool invalidElemTy = false;
// Memory space of created TensorDesc should match with the source.
@@ -280,31 +304,28 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
+ if (size_t offsetRank = getMixedOffsets().size())
+ invalidRank |= (offsetRank != rank);
+
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
- auto memrefTy = dyn_cast<MemRefType>(getSourceType());
- if (memrefTy) {
- invalidRank |= (memrefTy.getRank() != rank);
+ if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
- }
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
- return emitOpError("Expecting strides and shape to be present for "
+ return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
- // mismatches among shape, strides, and offsets are
- // already handeled by OffsetSizeAndStrideOpInterface.
- // So they are not check here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");
// check result TensorDesc rank
- if (getType().getRank() > rank)
+ if (getType().getRank() > (int64_t)rank)
return emitOpError(
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");
@@ -360,13 +381,10 @@ ParseResult parseOptionalDynamicIndexList(
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
DenseI64ArrayAttr integers) {
-
- if (!integers)
+ if (!integers || integers.empty())
return;
-
- return printDynamicIndexList(printer, op, values, integers,
- /*scalableFlags=*/{}, {},
- AsmParser::Delimiter::Square);
+ printDynamicIndexList(printer, op, values, integers,
+ /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
@@ -381,6 +399,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+ l2_hint, l3_hint);
+}
+
LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -423,6 +456,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
l3_hint);
}
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ UnitAttr packed, DenseI64ArrayAttr transpose,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -529,6 +578,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
@@ -635,10 +699,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CreateDescOp::verify() {
auto tdescTy = getTensorDescType();
- if (getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
@@ -673,12 +733,14 @@ LogicalResult CreateDescOp::verify() {
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -689,6 +751,13 @@ LogicalResult PrefetchOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+ auto srcTy = getSourceType();
+ if (srcTy.isInteger() && !getOffsetAlignByteAttr())
+ return emitOpError("offset_align_byte is required with integer source.");
+
+ if (getOffsetAlignByteAttr() && !srcTy.isInteger())
+ return emitOpError("offset_align_byte only allowed with integer source.");
+
return success();
}
@@ -696,7 +765,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
- build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
+ IntegerAttr{});
}
//===----------------------------------------------------------------------===//
@@ -707,13 +777,15 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
+
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -730,10 +802,11 @@ LogicalResult LoadGatherOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -746,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -754,12 +843,14 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
- if (!tdescTy && getRankOf(getDest()) > 1)
- return emitOpError(
- "Expecting the dest is a 1D memref or pointer (uint64_t).");
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -778,10 +869,11 @@ LogicalResult StoreScatterOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -794,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
@@ -888,8 +998,8 @@ LogicalResult ConvertLayoutOp::verify() {
// both input and target layouts should be WgLayout or SgLayout at the same
// time.
- if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
- (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+ if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
+ (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
return emitOpError("expected input layout and target layout be WgLayout or "
"SgLayout at the same time.");
@@ -928,9 +1038,107 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<FoldConvertLayoutOp>(context);
}
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadMatrixOp
+//===----------------------------------------------------------------------===//
+void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ DistributeLayoutAttr layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult LoadMatrixOp::verify() {
+ VectorType resTy = getRes().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> valueShape = resTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed mem_desc shape.");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreMatrixOp
+//===----------------------------------------------------------------------===//
+void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ DistributeLayoutAttr layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult StoreMatrixOp::verify() {
+ VectorType dataTy = getData().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("data shape must not exceed mem_desc shape.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescSubviewOp
+//===----------------------------------------------------------------------===//
+
+void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
+ Type resTy, Value src,
+ llvm::ArrayRef<OpFoldResult> offsets) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
+}
+
+LogicalResult MemDescSubviewOp::verify() {
+ MemDescType srcTy = getSrc().getType();
+ MemDescType resTy = getRes().getType();
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> resShape = resTy.getShape();
+
+ if (srcTy.getRank() < resTy.getRank())
+ return emitOpError("result rank must not exceed source rank.");
+
+ if (llvm::any_of(
+ llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed source shape.");
+
+ if (srcTy.getStrides() != resTy.getStrides())
+ return emitOpError("result must inherit the source strides.");
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
+namespace mlir {
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
+} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
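
The reworked isValidGatherScatterBufferParams above encodes a small decision table over the value, mask, and offsets types. The following dialect-free sketch restates those rules with plain containers standing in for the MLIR types; VecInfo, checkGatherScatter, and the message strings are illustrative names, not the dialect API.

    #include <cstdint>
    #include <optional>
    #include <string>
    #include <vector>

    // Illustrative stand-in for "is this operand a vector, and what shape has it".
    struct VecInfo {
      std::vector<int64_t> shape;
    };

    // Returns an error message, or std::nullopt on success, mirroring the patch:
    // a scalar result requires chunk size 1 and rejects any vector mask/offsets;
    // scalar mask and offsets (SIMT) require value size == chunk size; otherwise
    // the mask shape must match the value shape minus the trailing chunk dim.
    std::optional<std::string>
    checkGatherScatter(const std::optional<VecInfo> &value,
                       const std::optional<VecInfo> &mask,
                       const std::optional<VecInfo> &offsets, int64_t chunkSize) {
      auto numElems = [](const VecInfo &v) {
        int64_t n = 1;
        for (int64_t d : v.shape)
          n *= d;
        return n;
      };
      if (!value) {
        if (chunkSize > 1)
          return "expecting chunk size == 1 for scalar result";
        if (mask || offsets)
          return "expecting scalar mask and offsets";
        return std::nullopt;
      }
      if (!mask && !offsets) {
        if (numElems(*value) != chunkSize)
          return "value elements must match chunk size";
        return std::nullopt;
      }
      if (!mask)
        return "expecting a vector type mask";
      int64_t maskSize = numElems(*mask);
      if (chunkSize > 1) {
        if (value->shape.size() == 1 && numElems(*value) != chunkSize)
          return "value elements must match chunk size";
      } else if (numElems(*value) != maskSize) {
        return "mask should match value except the chunk size dim";
      }
      if (maskSize == 1)
        return std::nullopt;
      std::vector<int64_t> expectedMaskShape = value->shape;
      if (chunkSize > 1)
        expectedMaskShape.pop_back();
      if (expectedMaskShape != mask->shape)
        return "mask shape does not match value shape";
      return std::nullopt;
    }
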
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index d82c541..9ee002e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -84,9 +84,10 @@ struct ConvertLayoutOpPattern
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
PatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
- xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
- if (!input_layout.getInstData() || !target_layout.getInstData())
+ xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
+ xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
+ if (input_layout.getInstDataAsInt().empty() ||
+ target_layout.getInstDataAsInt().empty())
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
input_layout = input_layout.dropInstData();
@@ -140,10 +141,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
else
value = (Value)operandOrResult;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
- if (layout && layout.isSgLayout()) {
- if (auto inst_data = layout.getInstData())
- return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(operandOrResult);
+ if (layout && layout.isForSubgroup()) {
+ if (!layout.getInstDataAsInt().empty())
+ return layout.getInstDataAsInt();
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -204,13 +206,15 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
// skip the op if any of its operands or results has workgroup level layouts
bool hasWgLayoutOperands =
llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
- return layout && layout.isWgLayout();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(opr);
+ return layout && layout.isForWorkgroup();
});
bool hasWgLayoutResults =
llvm::any_of(op->getOpResults(), [](OpResult result) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
- return layout && layout.isWgLayout();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(result);
+ return layout && layout.isForWorkgroup();
});
if (hasWgLayoutOperands || hasWgLayoutResults) {
LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
@@ -220,8 +224,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
Type valTy = value.getType();
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
- xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
- return layout && layout.getInstData();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
+ return layout && !layout.getInstDataAsInt().empty();
}
auto shapedType = dyn_cast<ShapedType>(valTy);
return shapedType && !llvm::equal(tileShape, shapedType.getShape());
@@ -247,7 +251,8 @@ void XeGPUBlockingPass::runOnOperation() {
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
// This ensures that the LayoutAttr remains accessible even if the defining
// operation is replaced.
- xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
+ xegpu::setDistributeLayoutAttrs(
+ op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
xegpu::LayoutAttr layout) {
@@ -272,7 +277,7 @@ void XeGPUBlockingPass::runOnOperation() {
auto layout =
llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
- if (layout && layout.isWgLayout())
+ if (layout && layout.isForWorkgroup())
return failure();
int count;
@@ -289,7 +294,7 @@ void XeGPUBlockingPass::runOnOperation() {
ArrayRef<int64_t> shape = type.getShape();
xegpu::LayoutAttr layout = type.getLayoutAttr();
- if (layout && layout.isWgLayout())
+ if (layout && layout.isForWorkgroup())
return failure();
int count;
@@ -377,7 +382,7 @@ void XeGPUBlockingPass::runOnOperation() {
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<LoopLikeOpInterface>(op))
- xegpu::setLayoutAttr(result, layout.dropInstData());
+ xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
}
}
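
Throughout this file the pass now reads blocking information through the DistributeLayoutAttr interface (isForSubgroup, getInstDataAsInt) instead of the concrete LayoutAttr. A minimal sketch of the resulting tile-shape selection rule, assuming plain containers in place of MLIR attributes and shaped types (LayoutInfo and pickTileShape are illustrative names):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Illustrative stand-in for a distribute layout attribute.
    struct LayoutInfo {
      bool forSubgroup = false;
      std::vector<int64_t> instData; // empty when no inst_data is attached
    };

    // Prefer the layout's inst_data as the unroll/tile shape when the value
    // carries a subgroup-level layout; otherwise fall back to the value's own
    // shape; return nullopt when neither is available.
    std::optional<std::vector<int64_t>>
    pickTileShape(const std::optional<LayoutInfo> &layout,
                  const std::optional<std::vector<int64_t>> &valueShape) {
      if (layout && layout->forSubgroup) {
        if (!layout->instData.empty())
          return layout->instData;
        if (valueShape)
          return *valueShape;
      }
      return std::nullopt;
    }
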
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef8804..5cb47b2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
@@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder,
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
if (auto result = dyn_cast<OpResult>(successorInput))
- xegpu::setLayoutAttr(result, successorOperandLayout);
+ xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
}
}
return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8957ea5..dddb5ea 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -277,22 +277,13 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
descOp, "the tensor descriptor lacks layout attribute");
SmallVector<size_t> newRetIndices;
- SmallVector<Value> newYieldValues;
- SmallVector<Type> newYieldTypes;
-
- for (Value operand : descOp->getOperands()) {
- newYieldValues.push_back(operand);
- newYieldTypes.push_back(operand.getType());
- }
rewriter.setInsertionPoint(warpOp);
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, /* new yieled values = */ newYieldValues,
- /* new yielded types = */ newYieldTypes, newRetIndices);
+        rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
+ /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
- SmallVector<Value> newDescOperands;
- for (size_t i : newRetIndices) {
- newDescOperands.push_back(newWarpOp.getResult(i));
- }
+ SmallVector<Value> newDescOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
rewriter.setInsertionPointAfter(newWarpOp);
xegpu::TensorDescType distributedTensorDescTy =
descOp.getType().dropLayouts(); // Distributed tensor descriptor type
@@ -345,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -458,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
// Make sure the same load op is the last operation in the warp op body.
// This ensure that load op is not sinked earlier violating any barrier
// synchronizations.
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
return yield->getPrevNode() == op;
});
@@ -696,39 +685,30 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
warpOp, "warp result is not a xegpu::UpdateNdOffset op");
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
unsigned operandIdx = operand->getOperandNumber();
- // new update op does not have layout attribute.
- xegpu::TensorDescType newTensorDescTy =
- updateOp.getTensorDescType().dropLayouts();
- SmallVector<Value, 3> newYieldValues;
- SmallVector<Type, 3> newYieldTypes;
- for (Value operand : updateOp->getOperands()) {
- newYieldValues.push_back(operand);
- if (isa<xegpu::TensorDescType>(operand.getType())) {
- newYieldTypes.push_back(newTensorDescTy);
- } else {
- newYieldTypes.push_back(operand.getType());
- }
- }
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
+ newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- SmallVector<Value> newUpdateOperands;
- for (size_t i : newRetIndices) {
- // For the tensor descriptor operand, the layout attribute is dropped
- // after distribution. Types needs to be resolved in this case.
- if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
- newUpdateOperands.push_back(resolveDistributedTy(
- newWarpOp.getResult(i), newTensorDescTy, rewriter));
- } else {
- newUpdateOperands.push_back(newWarpOp.getResult(i));
- }
- }
+ // new update op does not have layout attribute.
+ xegpu::TensorDescType distributedTensorDescTy =
+ updateOp.getTensorDescType().dropLayouts();
+ SmallVector<Value> newUpdateOperands =
+ llvm::map_to_vector(newRetIndices, [&](size_t i) {
+ // For the tensor descriptor operand, the layout attribute is
+          // dropped after distribution. Types need to be resolved in this
+ // case.
+ if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+ return resolveDistributedTy(newWarpOp.getResult(i),
+ distributedTensorDescTy, rewriter);
+ }
+ return newWarpOp.getResult(i);
+ });
// Create a new update op outside the warp op.
auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
- rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
- updateOp->getAttrs());
+ rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
+ newUpdateOperands, updateOp->getAttrs());
xegpu::removeLayoutAttrs(newUpdateOp);
Value distributedVal = newWarpOp.getResult(operandIdx);
// Resolve the distributed type with the original type.
@@ -770,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -812,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
@@ -859,14 +837,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+ auto layout =
+ xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
signalPassFailure();
return;
}
- xegpu::setLayoutAttr(operand, layout);
+ xegpu::setDistributeLayoutAttr(operand, layout);
}
});
// Step 2: Move all operations of a GPU function inside
@@ -900,7 +879,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (vecRank == 0)
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+ // TODO: support more layout types
+ auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
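
Several of the distribution patterns in this file (see the CreateNdDescDistribution and UpdateNdOffsetDistribution hunks above) drop hand-rolled index loops in favor of llvm::map_to_vector. A minimal illustration of that idiom, using a made-up element type rather than the warp-op results from the patch:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/SmallVectorExtras.h"

    // Before: an explicit loop pushing f(x) for each x into a SmallVector.
    // After: llvm::map_to_vector builds the result vector in one expression.
    llvm::SmallVector<int> squareAll(const llvm::SmallVector<int> &xs) {
      return llvm::map_to_vector(xs, [](int x) { return x * x; });
    }
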
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70c..9f627c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,38 +34,29 @@ using namespace mlir;
namespace {
-// Check if there is sg id range attached to the scf.if op.
-static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
- int64_t &endOfRange) {
- Operation *parent = op->getParentOp();
- // Find the outermost scf::IfOp with xegpu.sg_id_range.
+// Retrieve the RangeAttr if it is specified.
+static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
+ Operation *parent = op->getParentOfType<scf::IfOp>();
while (parent) {
- if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
- if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
- ifOp->getAttr("sg_id_range"))) {
- startOfRange = attr.getStart().getInt();
- endOfRange = attr.getEnd().getInt();
- break;
- }
- }
- parent = parent->getParentOp();
+ if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
+ parent->getAttr("sg_id_range")))
+ return attr;
+ parent = parent->getParentOfType<scf::IfOp>();
}
- // Return false if startOfRange is 0
- return (startOfRange > 0 && endOfRange > startOfRange);
+ return {};
}
static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+ xegpu::DistributeLayoutAttr layout) {
int count = 1;
SmallVector<int64_t> sgShape(shape);
-
- if (layout && layout.isWgLayout()) {
- DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
- auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
- sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
- else
- sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+ if (layout && layout.isForWorkgroup()) {
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+ if (!layout.getSgDataAsInt().empty())
+ sgShape = layout.getSgDataAsInt();
+ else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+ sgShape = *maybeDerivedSgData;
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
return std::make_pair(sgShape, count);
}
+/// Utility helper for deriving the list of offsets for each sub-TensorDesc
+/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on
+/// the associated distribute layout attribute, the data shape, the subgroup
+/// id, and the original offsets of the op.
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
+ SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+  // Not applicable to ops without offset operands.
+ if (origOffsets.empty())
+ return failure();
+
+  // Not applicable to ops without a workgroup-level layout attribute.
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+ // verify and adjust the sgId if the range specifier is present
+ xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
+ if (sgIdRange) {
+ int64_t startOfRange = sgIdRange.getStart().getInt();
+ int64_t endOfRange = sgIdRange.getEnd().getInt();
+ // verify the RangeAttr against the layout attribute
+ if (layout.getNumSubgroups() != endOfRange - startOfRange)
+ return rewriter.notifyMatchFailure(
+ op, "sg_layout size must match the sg_id_range");
+ // adjust the sgId if necessary
+ if (startOfRange > 0) {
+ Value startOfRangeVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+ }
+ }
+
+ // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
+ // descriptors to be accessed, based on the layout information.
+ ArrayRef<int64_t> wgShape = op.getDataShape();
+ auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(maybeDescOffsets))
+ return failure();
+
+ // Compute the final global offsets for each accessed sub-tensor
+ // or sub-memory descriptor.
+ for (const auto &sgOffsets : *maybeDescOffsets) {
+ SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+ offsetsList.push_back(std::move(newOffsets));
+ }
+
+ return success();
+}
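
genOffsetsList composes the subgroup-relative offsets returned by layout.getOffsets with the op's original offsets through xegpu::addWithRightAligned. Assuming that helper adds the shorter offset list onto the trailing dimensions of the longer one (an assumption based on its name; its definition is outside this patch), the arithmetic reduces to this sketch:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Adds sgOffsets onto the trailing dimensions of orig, e.g.
    // orig = {0, 32, 64}, sgOffsets = {8, 16}  ->  {0, 40, 80}.
    std::vector<int64_t>
    addWithRightAlignedSketch(std::vector<int64_t> orig,
                              const std::vector<int64_t> &sgOffsets) {
      assert(sgOffsets.size() <= orig.size() && "offsets rank mismatch");
      size_t start = orig.size() - sgOffsets.size();
      for (size_t i = 0; i < sgOffsets.size(); ++i)
        orig[start + i] += sgOffsets[i];
      return orig;
    }
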
+
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
@@ -125,125 +177,74 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
- // Calculate offset for each subgroup
- static SmallVector<OpFoldResult>
- calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
- const SmallVector<OpFoldResult> &originalOffsets,
- const SmallVector<Value> &localOffset,
- const SmallVector<int64_t> &distUnitBaseAddr,
- const SmallVector<int64_t> &distUnitShape) {
- assert(localOffset.size() == distUnitBaseAddr.size() &&
- "localOffset and distUnitBaseAddr must have the same rank");
-
- SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
- originalOffsets.end());
- size_t rank = localOffset.size();
- for (size_t i = 0; i < rank; ++i) {
- size_t dimIdx = originalOffsets.size() - rank + i;
- Value constOffset =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]);
- Value offset =
- rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
- Value modValue =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]);
- Value offsetMod =
- rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
- Value origOffset = getValueOrCreateConstantIndexOp(
- rewriter, loc, originalOffsets[dimIdx]);
- Value globalOffset =
- rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
- globalOffsets[dimIdx] = globalOffset;
- }
-
- return globalOffsets;
- }
-
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
- auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
- if (!layout)
- return failure();
- Type elemTy = tdescTy.getElementType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
- // sgLayout must be present for workgroup-level distribution.
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
- return rewriter.notifyMatchFailure(
- op, "sgLayout attribute is required in layout");
-
+ Type elemTy = tdescTy.getElementType();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ auto newTdescTy =
+ xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+ layout.dropSgLayoutAndData());
- // TODO : Handle order attribute
- // Get the subgroup ID
- auto linearSgId =
- gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- // Create constants for layout dimensions
- SmallVector<Value> sgLayoutDim(sgLayout.size());
- SmallVector<Value> sgDataDim(sgShape.size());
+ SmallVector<Value> newOps;
+ for (auto offsets : offsetsList) {
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
+ op.getMixedSizes(), op.getMixedStrides());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- sgLayoutDim[i] =
- arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
- sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ newOps.push_back(newOp);
}
+ rewriter.replaceOpWithMultiple(op, {newOps});
- int64_t startOfRange = -1, endOfRange = -1;
- bool sgIdRangeSpecified =
- isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
- Value adjustedSgId = linearSgId;
- if (sgIdRangeSpecified) {
- int64_t sgCount = endOfRange - startOfRange;
- if (computeProduct(sgLayout) != sgCount)
- return rewriter.notifyMatchFailure(
- op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get the adjusted
- // sg id
- Value startOfRangeVal =
- arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- adjustedSgId =
- rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
- }
+ return success();
+ }
+};
+
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
- auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
- if (failed(deLinearizeSgId))
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Check no offsets are specified.
+ if (!op.getMixedOffsets().empty())
return failure();
- SmallVector<Value> sgIds = *deLinearizeSgId;
-
- // Calculate distribution unit shape and local offsets for subgroup
- SmallVector<int64_t> distUnitShape(sgLayout.size());
- SmallVector<Value> localOffset(sgLayout.size());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
- localOffset[i] =
- rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
- }
- SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
xegpu::TensorDescType newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
layout.dropSgLayoutAndData());
- SmallVector<Value> newCreateNdOps;
- for (SmallVector<int64_t> distUnitBaseAddr :
- StaticTileOffsetRange(wgShape, distUnitShape)) {
- SmallVector<OpFoldResult> globalOffsets =
- calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
- distUnitBaseAddr, distUnitShape);
-
- auto newCreateNdOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
- op.getMixedSizes(), op.getMixedStrides());
- newCreateNdOps.push_back(newCreateNdOp);
- }
+
+ SmallVector<Value> newCreateNdOps(count);
+ std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+ return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+ op.getSource(), op.getMixedSizes(),
+ op.getMixedStrides());
+ });
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
return success();
@@ -256,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
LogicalResult
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> newLoadOps;
-
- int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
+ if (!op.getMixedOffsets().empty())
return failure();
+ SmallVector<Value> newLoadOps;
for (auto src : adaptor.getTensorDesc()) {
xegpu::TensorDescType tdescTy =
dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -284,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
LogicalResult
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
+ if (!op.getMixedOffsets().empty())
return failure();
for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
@@ -298,6 +295,84 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ SmallVector<Value> newOps;
+ for (auto [tdesc, offsets] :
+ llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+ auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+ VectorType newResTy =
+ VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+ auto newOp = xegpu::LoadNdOp::create(
+ rewriter, op.getLoc(), newResTy, tdesc, offsets,
+ /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+ : public OpConversionPattern<xegpu::StoreNdOp> {
+ using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ for (auto [v, tdesc, offsets] :
+ llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
+ rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+ : public OpConversionPattern<xegpu::PrefetchNdOp> {
+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ for (auto [tdesc, offsets] :
+ llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+ rewriter.create<xegpu::PrefetchNdOp>(
+ op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
@@ -331,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
if (resultTy.getRank() != 2)
return failure();
- auto originalLayout = xegpu::getLayoutAttr(op.getResult());
+ auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
if (!originalLayout)
return failure();
@@ -354,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
resultTy.getElementType());
tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
- xegpu::setLayoutAttr(cast<OpResult>(tmpC),
- originalLayout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
+ originalLayout.dropSgLayoutAndData());
newDpasOps.push_back(tmpC);
}
@@ -395,8 +470,9 @@ struct WgToSgVectorBroadcastOp
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
return failure();
// TODO: Currently only supports cases where the source and result ranks
@@ -411,10 +487,8 @@ struct WgToSgVectorBroadcastOp
VectorType::get(sgShape, resultType.getElementType());
// Check if the output layout is distributable
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+ if (sgLayout.empty())
return failure();
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
@@ -433,8 +507,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -460,8 +534,9 @@ struct WgToSgElementwiseOp : public ConversionPattern {
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
+ if (!layout || !layout.isForWorkgroup())
return failure();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -526,8 +601,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
// is lowered to:
// #a = #xegpu.layout<inst_data = [16, 16]>
// #b = #xegpu.layout<inst_data = [8, 16]>
-// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
// clang-format on
struct WgToSgConvertLayoutOp
@@ -536,10 +611,12 @@ struct WgToSgConvertLayoutOp
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input = op.getInputLayout();
- xegpu::LayoutAttr target = op.getTargetLayout();
+ // TODO: currently, we only support LayoutAttr
+ auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
+ auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
- if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+ if (!input || !target || !input.isForWorkgroup() ||
+ !target.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
@@ -649,16 +726,213 @@ struct UnrealizedConversionCastOpPattern
}
};
+// This pattern distributes arith.constant op into subgroup-level constants
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ ArrayRef<int64_t> wgShape = vecType.getShape();
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+    // Current limitation: only splat vector constants (a single repeated
+    // value) are supported.
+    // TODO: support vectors with multiple distinct values.
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+
+ auto newType = VectorType::get(sgShape, vecType.getElementType());
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ SmallVector<Value> newConsts(count, cstOp);
+
+ rewriter.replaceOpWithMultiple(op, {newConsts});
+ return success();
+ }
+};
+
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data
+struct WgToSgLoadGatherOpWithOffset
+ : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return failure();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ // The offsets need to be distributed
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
+ SmallVector<Value> newLoadOps;
+ auto chunkSizeAttr =
+ rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+ VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+ for (auto [offsets, mask] :
+ llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newLoadOps.push_back(newLoadOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newLoadOps});
+ return success();
+ }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data
+struct WgToSgStoreScatterOpWithOffset
+ : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+ if (!valueType)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getValue());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ // The offsets need to be distributed
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
+ auto chunkSizeOpt = op.getChunkSize();
+ int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+ for (auto [val, offs, mask] : llvm::zip(
+ adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      auto newOp = rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      // Update the layout attribute to drop sg_layout and sg_data, and attach
+      // it to the newly created op (the original op is erased below).
+      if (auto newLayout = layout.dropSgLayoutAndData())
+        newOp->setAttr("layout", newLayout);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+ using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ ArrayRef<int64_t> wgShape = op.getDataShape();
+ VectorType valueTy = op.getRes().getType();
+ Type elemTy = valueTy.getElementType();
+
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResTy = VectorType::get(sgShape, elemTy);
+ SmallVector<Value> newOps;
+ for (auto offsets : offsetsList) {
+ auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+ op.getLoc(), newResTy, op.getMemDesc(), offsets,
+ layout.dropSgLayoutAndData());
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
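+// This pattern distributes a workgroup-level StoreMatrixOp into per-subgroup
+// stores, pairing each distributed value with its subgroup offsets.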
+struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
+ using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
+ rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
+ offsets,
+ layout.dropSgLayoutAndData());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
- patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ patterns
+ .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+ WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+ WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+ WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+ WgToSgStoreMatrixOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -748,8 +1022,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return xegpu::TensorDescType();
};
- auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
- return !layout || !layout.isWgLayout();
+ auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
+ return !layout || !layout.isForWorkgroup();
};
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
@@ -761,13 +1035,46 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
});
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
- auto layout = xegpu::getLayoutAttr(op.getResult());
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
+ [=](xegpu::LoadMatrixOp op) -> bool {
+ return isLegal(op.getLayoutAttr());
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
+ [=](xegpu::StoreMatrixOp op) -> bool {
+ return isLegal(op.getLayoutAttr());
+ });
+
+ target.addDynamicallyLegalOp<arith::ConstantOp>(
+ [=](arith::ConstantOp op) -> bool {
+ auto vecType = dyn_cast<VectorType>(op.getType());
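+ // Non-vector constants are not distributed and remain legal.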
+ if (!vecType)
+ return true;
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+ });
+
+ target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+ [=](xegpu::LoadGatherOp op) -> bool {
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+ [=](xegpu::StoreScatterOp op) -> bool {
+ // StoreScatterOp has no results; its layout is attached directly to the op.
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+ if (!layout)
+ return true;
+ return isLegal(layout);
+ });
+
target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
@@ -795,7 +1102,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
}
}
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
return isLegal(layout);
});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index 98e84a4..d9bf4a1 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils
LINK_LIBS PUBLIC
MLIRIR
MLIRSCFTransforms
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRXeGPUDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2cf21fb..cac1ffe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,6 +11,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -38,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
// It only works for subgroup level layout, which only has lane_layout
// and lane_data, and is to distribute a SIMD code into SIMT code.
- if (!layout || !layout.isSgLayout())
+ if (!layout || !layout.isForSubgroup())
return failure();
SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
@@ -111,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) {
return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (!value)
return nullptr;
@@ -129,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
// for LoadNdOp, the layout is stored in the tensor descriptor
if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
- return getLayoutAttr(loadNd.getTensorDesc());
+ return getDistributeLayoutAttr(loadNd.getTensorDesc());
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
- return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+ return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -141,49 +144,51 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
- return getLayoutAttr(tiedInit->get());
+ return getDistributeLayoutAttr(tiedInit->get());
}
}
return nullptr;
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
- return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
- return getLayoutAttr(opr.get());
+ return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return getDistributeLayoutAttr(opr.get());
}
template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
+ const DistributeLayoutAttr layout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
+ if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->setAttr(name, layout);
}
// Explicit instantiation for OpResult
-template void
-xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
- const mlir::xegpu::LayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
+ const mlir::OpResult &result,
+ const mlir::xegpu::DistributeLayoutAttr layout);
// Explicit instantiation for OpOperand
-template void
-xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
- const mlir::xegpu::LayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
+ const mlir::OpOperand &operand,
+ const mlir::xegpu::DistributeLayoutAttr layout);
-void xegpu::setLayoutAttrs(Operation *op,
- function_ref<LayoutAttr(Value)> getLayoutImpl) {
+void xegpu::setDistributeLayoutAttrs(
+ Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getLayoutImpl(opr.get());
- setLayoutAttr(opr, layout);
+ setDistributeLayoutAttr(opr, layout);
}
for (OpResult result : nestOp->getOpResults()) {
auto layout = getLayoutImpl(result);
- setLayoutAttr(result, layout);
+ setDistributeLayoutAttr(result, layout);
}
});
}
@@ -192,7 +197,7 @@ template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (owner->hasAttrOfType<LayoutAttr>(name))
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->removeAttr(name);
}
@@ -303,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
if (!inputTy || !resultTy)
return WalkResult::skip();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(input);
if (!layout)
return WalkResult::skip();
@@ -341,7 +347,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
}
{ // perform the conversion from RankedTensorType to VectorType based on the
- // LayoutAttr
+ // DistributeLayoutAttr
// Handle the UnrealizedConversionCastOp introduced by the first step.
// For vector->RankedTensorType, it will simply forward the inputs.
@@ -404,3 +410,49 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
(void)mlir::applyPartialConversion(op, target, std::move(patterns));
}
}
+
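+/// Returns the chip name from the XeVM target attribute of the enclosing
+/// gpu.module, or std::nullopt if no such attribute is found.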
+std::optional<std::string> xegpu::getChipStr(Operation *op) {
+ auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
+
+ if (!gpuModuleOp)
+ return std::nullopt;
+
+ auto targetAttrs = gpuModuleOp.getTargets();
+ if (targetAttrs) {
+ for (auto &attr : *targetAttrs) {
+ auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
+ if (xevmAttr)
+ return xevmAttr.getChip().str();
+ }
+ }
+
+ return std::nullopt;
+}
+
+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+/// lhs = [l1, l2, l3], rhs = [r1, r2]
+/// Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult>
+xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> lhs,
+ ArrayRef<OpFoldResult> rhs) {
+ // Ensure a is at least as long as b.
+ ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+ ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+ SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+ a = a.slice(a.size() - b.size());
+ for (auto [l, r] : llvm::zip(a, b)) {
+ auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+ auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+ results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ }
+ return results;
+}
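+
+// Illustrative use (hypothetical values): given workgroup offsets {%y, %x}
+// and a subgroup adjustment {%sgx},
+// addWithRightAligned(builder, loc, {%y, %x}, {%sgx})
+// yields {%y, %x + %sgx}; the unmatched leading offset %y is passed through.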