Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 210
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp | 34
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/Utils.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 4
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Async/IR/Async.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp | 22
-rw-r--r--  mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp | 2
-rw-r--r--  mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp | 34
-rw-r--r--  mlir/lib/Dialect/DLTI/DLTI.cpp | 5
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 33
-rw-r--r--  mlir/lib/Dialect/Func/Utils/Utils.cpp | 25
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp | 1
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 12
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp | 8
-rw-r--r--  mlir/lib/Dialect/IRDL/IRDLLoading.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Index/IR/IndexDialect.cpp | 14
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 11
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 19
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 30
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2040
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp | 45
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 23
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 301
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 68
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 30
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 36
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 247
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 740
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp | 1
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 19
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 7
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 6
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp | 32
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp | 33
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 89
-rw-r--r--  mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 7
-rw-r--r--  mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 6
-rw-r--r--  mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp | 7
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 636
-rw-r--r--  mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp | 781
-rw-r--r--  mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp | 431
-rw-r--r--  mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp | 237
-rw-r--r--  mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp | 117
-rw-r--r--  mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt | 7
-rw-r--r--  mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp | 111
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 32
-rw-r--r--  mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 186
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp | 1
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 216
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 83
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp | 77
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 82
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp | 1
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp | 72
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h | 2
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 76
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 56
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp | 85
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 29
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt | 4
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp | 111
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp | 310
-rw-r--r--  mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp | 13
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 26
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp | 50
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 285
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 43
-rw-r--r--  mlir/lib/Dialect/X86Vector/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt | 17
-rw-r--r--  mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp | 64
-rw-r--r--  mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt | 3
-rw-r--r--  mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp | 143
-rw-r--r--  mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp | 301
-rw-r--r--  mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 380
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 137
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt | 17
-rw-r--r--  mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp | 695
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp | 490
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 611
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 589
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 27
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 267
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 97
127 files changed, 11075 insertions, 1399 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index df955fc..b7a665b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
>();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
+ >();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
@@ -339,19 +343,45 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
}
//===----------------------------------------------------------------------===//
-// ScaledExtPacked816Op
+// ScaledExtPackedMatrixOp
//===----------------------------------------------------------------------===//
-LogicalResult ScaledExtPacked816Op::verify() {
+LogicalResult ScaledExtPackedMatrixOp::verify() {
int blockSize = getBlockSize();
- assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+ assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
+
int firstScaleByte = getFirstScaleByte();
- if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
- return emitOpError(
- "blockSize of 16 can only have firstScaleByte be 0 or 1.");
- }
- if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
- return emitOpError(
- "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+ int firstScaleLane = getFirstScaleLane();
+ auto sourceType = cast<VectorType>(getSource().getType());
+ Type elementType = sourceType.getElementType();
+ auto floatType = cast<FloatType>(elementType);
+ unsigned bitWidth = floatType.getWidth();
+
+ assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+ const bool is_fp8 = bitWidth == 8;
+ const bool is_block_16 = blockSize == 16;
+
+ if (!is_fp8) {
+ if (is_block_16) {
+ if (!llvm::is_contained({0, 1}, firstScaleByte)) {
+ return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
+ "or 1 for f4 and f6.");
+ }
+ } else {
+ if (!llvm::is_contained({0, 2}, firstScaleByte)) {
+ return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
+ "or 2 for f4 and f6.");
+ }
+ }
+ } else {
+ if (is_block_16) {
+ bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
+ ((firstScaleLane == 16) && (firstScaleByte == 2));
+ if (!is_valid) {
+ return emitOpError("blockSize of 16 can only have (firstScaleLane, "
+ "firstScaleByte) be (0, 0) or (16, 2) for f8.");
+ }
+ }
}
return success();
@@ -567,6 +597,53 @@ LogicalResult PermlaneSwapOp::verify() {
}
//===----------------------------------------------------------------------===//
+// MemoryCounterWaitOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Fuse adjacent memory counter wait ops, taking the minimum value of the
+/// counters.
+struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
+ PatternRewriter &rewriter) const override {
+ auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
+ if (!next)
+ return failure();
+
+ auto setters = {&MemoryCounterWaitOp::setLoad,
+ &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
+ &MemoryCounterWaitOp::setExp,
+ &MemoryCounterWaitOp::setTensor};
+ auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
+ op.getTensor()};
+ auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
+ next.getExp(), next.getTensor()};
+ rewriter.modifyOpInPlace(op, [&] {
+ for (auto [setter, lhs, rhs] :
+ llvm::zip_equal(setters, lhsVals, rhsVals)) {
+ if (lhs && rhs) {
+ (op.*setter)(std::min(*lhs, *rhs));
+ } else if (lhs) {
+ (op.*setter)(*lhs);
+ } else if (rhs) {
+ (op.*setter)(*rhs);
+ }
+ }
+ });
+ rewriter.eraseOp(next);
+ return success();
+ }
+};
+} // namespace
+
+void MemoryCounterWaitOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<FuseMemoryCounterWaitOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
@@ -662,19 +739,123 @@ LogicalResult TransposeLoadOp::verify() {
};
auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
- if (validNumElems == kValidLoadSizeMap.end()) {
+ if (validNumElems == kValidLoadSizeMap.end())
return emitOpError("Unsupported element type size for transpose load: ")
<< elementTypeSize << " bits";
- }
- if (numElements != validNumElems->second) {
+
+ if (numElements != validNumElems->second)
return emitOpError(
"Transferring type size mismatch: expected num of elements: ")
<< validNumElems->second;
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MakeDmaBaseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaBaseOp::verify() {
+
+ auto ldsType = cast<MemRefType>(getLds().getType());
+ auto globalType = cast<MemRefType>(getGlobal().getType());
+ if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
+ return emitOpError(
+ "lds memref must have workgroup address space attribute.");
+ if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
+ return emitOpError(
+ "global memref must have global address space attribute.");
+
+ Type elementType = ldsType.getElementType();
+ unsigned width = elementType.getIntOrFloatBitWidth();
+
+ if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+ return emitOpError(
+ "element type must be 1, 2, 4, or 8 bytes long but type was ")
+ << width << " bits long.";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MakeDmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaDescriptorOp::verify() {
+ ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
+
+ if (globalStaticStrides.empty())
+ return emitOpError("strides must not be empty.");
+ if (globalStaticStrides.back() != 1)
+ return emitOpError("strides for the innermost dimension must be 1.");
+
+ ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
+ size_t rank = globalStaticSizes.size();
+ if (rank > 5)
+ return emitOpError("tensor and tile must be at most of rank 5.");
+ if (rank != globalStaticStrides.size())
+ return emitOpError("strides and sizes must have same rank.");
+
+ ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
+ if (rank != sharedStaticSizes.size())
+ return emitOpError("tensor must have same rank as tile.");
+
+ unsigned elementTypeWidth = getElementTypeWidth();
+ if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth))
+ return emitOpError(
+ "element type width must be 1, 2, 4 or 8 bytes, but was ")
+ << elementTypeWidth << " bits long";
+
+ if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
+ auto atomicBarrierAddressType =
+ cast<MemRefType>(atomicBarrierAddress.getType());
+ bool barrierInLDS =
+ hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
+ if (!barrierInLDS)
+ return emitOpError("atomic barrier address must be in LDS.");
}
+ if (getEarlyTimeout() && !getWorkgroupMask())
+ return emitOpError(
+ "early timeout does not apply when workgroup_mask is not set.");
return success();
}
+OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
+ SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes());
+ SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides());
+ SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes());
+
+ if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/true)) &&
+ failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/true)) &&
+ failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/true)))
+ return nullptr;
+
+ SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
+ dynamicSharedSizes;
+ SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
+ staticSharedSizes;
+
+ dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
+ staticGlobalSizes);
+ setGlobalStaticSizes(staticGlobalSizes);
+ getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
+
+ dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
+ staticGlobalStrides);
+ setGlobalStaticStrides(staticGlobalStrides);
+ getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
+
+ dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
+ staticSharedSizes);
+ setSharedStaticSizes(staticSharedSizes);
+ getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
+ return getResult();
+}
+
//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//
@@ -813,5 +994,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c..89ef51f 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
/// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
- vector::MaskedLoadOp maskedOp) {
- auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+ auto memRefType = dyn_cast<MemRefType>(type);
if (!memRefType)
- return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+ return failure();
Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
- return rewriter.notifyMatchFailure(maskedOp, "no address space");
+ return failure();
if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
- return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+ return failure();
return success();
}
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
PatternRewriter &rewriter) const override {
if (maskedOp->hasAttr(kMaskedloadNeedsMask))
- return failure();
+ return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
- if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
- return failure();
+ if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+ return rewriter.notifyMatchFailure(
+ maskedOp, "isn't a load from a fat buffer resource");
}
// Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
PatternRewriter &rewriter) const override {
+ if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+ return rewriter.notifyMatchFailure(
+ loadOp, "buffer loads are handled by a more specialized pattern");
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
if (failed(maybeCond)) {
- return failure();
+ return rewriter.notifyMatchFailure(loadOp,
+ "isn't loading a broadcasted scalar");
}
Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
PatternRewriter &rewriter) const override {
+ // A condition-free implementation of fully masked stores requires
+ // 1) an accessor for the num_records field on buffer resources/fat pointers
+ // 2) knowledge that said field will always be set accurately - that is,
+ // that writes to x < num_records of offset wouldn't trap, which is
+ // something a pattern user would need to assert or we'd need to prove.
+ //
+ // Therefore, conditional stores to buffers still go down this path at
+ // present.
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
if (failed(maybeCond)) {
return failure();
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 4d2d873..3d1a734 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
.Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+ .Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; })
+ .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
+ .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
.Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
- // TODO: AtomicRMW supports other kinds of reductions this is
- // currently not detecting, add those when the need arises.
return std::nullopt;
});
if (!maybeKind)
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index b405ec2..edfae7e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -342,8 +342,7 @@ void FlatAffineValueConstraints::getIneqAsAffineValueMap(
if (inequality[pos] > 0)
// Lower bound.
- std::transform(bound.begin(), bound.end(), bound.begin(),
- std::negate<int64_t>());
+ llvm::transform(bound, bound.begin(), std::negate<int64_t>());
else
// Upper bound (which is exclusive).
bound.back() += 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 0c35921..c6addfb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5421,7 +5421,7 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
return rewriter.notifyMatchFailure(op,
"no unit basis entries to replace");
- if (newIndices.size() == 0) {
+ if (newIndices.empty()) {
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
return success();
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index c942c02..b04e2d6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -82,7 +82,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
newShape[0] = 2;
- std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
+ llvm::copy(oldShape, newShape.begin() + 1);
return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
};
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 4743941..8f1249e 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1711,6 +1711,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
outermost.getBody()->getOperations().splice(
Block::iterator(secondOutermostLoop.getOperation()),
innermost.getBody()->getOperations());
+ for (auto [iter, init] :
+ llvm::zip_equal(secondOutermostLoop.getRegionIterArgs(),
+ secondOutermostLoop.getInits())) {
+ iter.replaceAllUsesWith(init);
+ iter.dropAllUses();
+ }
secondOutermostLoop.erase();
return success();
}
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 845be20..deba160 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1327,9 +1327,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
cast<MemRefType>(newMemRef.getType()).getElementType());
- std::unique_ptr<DominanceInfo> domInfo;
- std::unique_ptr<PostDominanceInfo> postDomInfo;
-
// Walk all uses of old memref; collect ops to perform replacement. We use a
// DenseSet since an operation could potentially have multiple uses of a
// memref (although rare), and the replacement later is going to erase ops.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9f..e256915 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -389,8 +389,8 @@ def TruncIExtUIToExtUI :
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
- (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
- (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
+ (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow),
+ (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
//===----------------------------------------------------------------------===//
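
The trunci(shrsi(x, c)) -> trunci(shrui(x, c)) rewrite above (now also forwarding the exact flag) relies on every sign-fill bit introduced by the arithmetic shift being discarded by the truncation. A compile-time check of one such case (i32 source, i8 result, shift of 24), using C++20 shift and narrowing-conversion semantics; the concrete values are illustrative only:

#include <cstdint>

// -256 has the sign bit set, so arithmetic and logical right shifts by 24
// differ only in the upper 24 bits of the shifted value; truncating to 8 bits
// throws exactly those bits away, so both forms yield the same i8 value (-1).
static_assert(static_cast<int8_t>(int32_t{-256} >> 24) ==
                  static_cast<int8_t>(uint32_t{0xFFFFFF00u} >> 24),
              "i32 -> i8 truncation discards the sign-fill bits");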
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index adeb50b..c4e81e5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -35,7 +35,7 @@ static Value createConst(Location loc, Type type, int value,
}
/// Create a float constant.
-static Value createFloatConst(Location loc, Type type, APFloat value,
+static Value createFloatConst(Location loc, Type type, const APFloat &value,
PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 39e398b..cb7c3d7 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -150,7 +150,7 @@ public:
rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
}
- auto extOp = op.getLhs().getDefiningOp();
+ auto *extOp = op.getLhs().getDefiningOp();
arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
@@ -311,8 +311,8 @@ public:
rhsMask = packInputs(rhs0Mask, rhs1Mask);
}
- auto lhsExtOp = op.getLhs().getDefiningOp();
- auto rhsExtOp = op.getRhs().getDefiningOp();
+ auto *lhsExtOp = op.getLhs().getDefiningOp();
+ auto *rhsExtOp = op.getRhs().getDefiningOp();
arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 8e4a49d..e19b917 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -17,8 +17,6 @@ using namespace mlir::async;
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
-constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
-
void AsyncDialect::initialize() {
addOperations<
#define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index e0cf353..9b11270 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return false;
}
-// bufferization.to_buffer is not allowed to change the rank.
-static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
-#ifndef NDEBUG
- auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
- assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
- rankedTensorType.getRank()) &&
- "to_buffer would be invalid: mismatching ranks");
-#endif
-}
-
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
@@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *bufferType);
+
return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
*bufferType, value)
.getResult();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdf..bd177ba 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -21,25 +21,6 @@ using namespace mlir::bufferization;
#include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc"
-/// Attribute name used to mark function arguments who's buffers can be written
-/// to during One-Shot Module Bufferize.
-constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName;
-
-/// Attribute name used to mark the bufferization layout for region arguments
-/// during One-Shot Module Bufferize.
-constexpr const ::llvm::StringLiteral
- BufferizationDialect::kBufferLayoutAttrName;
-
-/// An attribute that can be attached to ops with an allocation and/or
-/// deallocation side effect. It indicates that the op is under a "manual
-/// deallocation" scheme. In the case of an allocation op, the returned
-/// value is *not* an automatically managed allocation and assigned an
-/// ownership of "false". Furthermore, only deallocation ops that are
-/// guaranteed to deallocate a buffer under "manual deallocation" are
-/// allowed to have this attribute. (Deallocation ops without this
-/// attribute are rejected by the ownership-based buffer deallocation pass.)
-constexpr const ::llvm::StringLiteral BufferizationDialect::kManualDeallocation;
-
//===----------------------------------------------------------------------===//
// Bufferization Dialect Interfaces
//===----------------------------------------------------------------------===//
@@ -73,9 +54,6 @@ struct BuiltinTensorExternalModel
mlir::LogicalResult verifyCompatibleBufferType(
mlir::Type tensor, BufferLikeType bufferType,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
- assert(isa<TensorType>(tensor) && "expected tensor type");
- assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
-
auto tensorType = cast<ShapedType>(tensor);
auto memrefType = cast<ShapedType>(bufferType);
diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
index 51feec7..f8eb45c 100644
--- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
+++ b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
@@ -17,6 +17,10 @@
// Pipeline implementation.
//===----------------------------------------------------------------------===//
+void mlir::bufferization::buildBufferDeallocationPipeline(OpPassManager &pm) {
+ buildBufferDeallocationPipeline(pm, BufferDeallocationPipelineOptions());
+}
+
void mlir::bufferization::buildBufferDeallocationPipeline(
OpPassManager &pm, const BufferDeallocationPipelineOptions &options) {
memref::ExpandReallocPassOptions expandAllocPassOptions{
@@ -44,5 +48,7 @@ void mlir::bufferization::registerBufferizationPipelines() {
"The default pipeline for automatically inserting deallocation "
"operations after one-shot bufferization. Deallocation operations "
"(except `memref.realloc`) may not be present already.",
- buildBufferDeallocationPipeline);
+ [](OpPassManager &pm, const BufferDeallocationPipelineOptions &options) {
+ buildBufferDeallocationPipeline(pm, options);
+ });
}
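
For reference, the new zero-argument overload makes programmatic use of the pipeline a one-liner. A minimal sketch, assuming a PassManager rooted at a module with the relevant dialects registered; the wrapper name is illustrative:

#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Pass/PassManager.h"

// Builds the default deallocation pipeline; equivalent to passing a
// default-constructed BufferDeallocationPipelineOptions explicitly.
static void addDefaultDeallocationPipeline(mlir::PassManager &pm) {
  mlir::bufferization::buildBufferDeallocationPipeline(pm);
}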
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b9ee0a4..d0742ec 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -217,7 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
}
if (!options.filterFn(&callee))
return;
- if (callee.isExternal() || callee.isPublic())
+ if (callee.isPublic() && !options.modifyPublicFunctions)
+ return;
+ if (callee.isExternal())
return;
SmallVector<Value, 6> replaceWithNewCallResults;
@@ -295,7 +297,9 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
// function.
AllocDynamicSizesMap map;
for (auto func : module.getOps<func::FuncOp>()) {
- if (func.isExternal() || func.isPublic())
+ if (func.isPublic() && !options.modifyPublicFunctions)
+ continue;
+ if (func.isExternal())
continue;
if (!options.filterFn(&func))
continue;
@@ -326,6 +330,8 @@ struct BufferResultsToOutParamsPass
options.hoistStaticAllocs = true;
if (hoistDynamicAllocs)
options.hoistDynamicAllocs = true;
+ if (modifyPublicFunctions)
+ options.modifyPublicFunctions = true;
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1784964..677c0ba 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace bufferization {
@@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- return {};
+ if (!insertionPoint) {
+ // If no already suitable insertion point was found, attempt to move all
+ // needed values before the user.
+ if (failed(moveValueDefinitions(rewriter, neededValues, user)))
+ return {};
+ insertionPoint = user;
+ }
rewriter.setInsertionPoint(insertionPoint);
Value replacement =
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 9ccbfd3..5dfe3e6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -497,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state,
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
state.findValueInReverseUseDefChain(opOperand, matchingSubset);
- return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
+ return llvm::all_of(backwardSlice, matchingSubset);
}
/// Return "true" if the given "read" and potentially conflicting "write" are
diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
index 58551bb..05a787f 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f1da1a1..d2078d8 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -445,6 +446,37 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
return success(replaced);
}
};
+
+/// If the destination block of a conditional branch contains only
+/// ub.unreachable, unconditionally branch to the other destination.
+struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // If the "true" destination is unreachable, branch to the "false"
+ // destination.
+ Block *trueDest = condbr.getTrueDest();
+ Block *falseDest = condbr.getFalseDest();
+ if (llvm::hasSingleElement(*trueDest) &&
+ isa<ub::UnreachableOp>(trueDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
+ condbr.getFalseOperands());
+ return success();
+ }
+
+ // If the "false" destination is unreachable, branch to the "true"
+ // destination.
+ if (llvm::hasSingleElement(*falseDest) &&
+ isa<ub::UnreachableOp>(falseDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
+ condbr.getTrueOperands());
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -452,7 +484,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 173d58b..da572f1 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -606,11 +606,6 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringRef> keys,
return dlti::query(op, entryKeys, emitError);
}
-constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
-constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
-constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
-constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessLittle;
-
namespace {
class TargetDataLayoutInterface : public DataLayoutDialectInterface {
public:
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 0992ce14..b0566dd 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -226,6 +226,21 @@ FailureOr<SmallVector<ReplacementItem>> parseFormatString(
}
//===----------------------------------------------------------------------===//
+// AddressOfOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AddressOfOp::verify() {
+ emitc::LValueType referenceType = getReference().getType();
+ emitc::PointerType resultType = getResult().getType();
+
+ if (referenceType.getValueType() != resultType.getPointee())
+ return emitOpError("requires result to be a pointer to the type "
+ "referenced by operand");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
@@ -380,6 +395,20 @@ LogicalResult emitc::ConstantOp::verify() {
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
//===----------------------------------------------------------------------===//
+// DereferenceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DereferenceOp::verify() {
+ emitc::PointerType pointerType = getPointer().getType();
+
+ if (pointerType.getPointee() != getResult().getType().getValueType())
+ return emitOpError("requires result to be an lvalue of the type "
+ "pointed to by operand");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ExpressionOp
//===----------------------------------------------------------------------===//
@@ -584,6 +613,10 @@ void ForOp::print(OpAsmPrinter &p) {
LogicalResult ForOp::verifyRegions() {
// Check that the body defines as single block argument for the induction
// variable.
+ if (getBody()->getNumArguments() != 1)
+ return emitOpError("expected body to have a single block argument for the "
+ "induction variable");
+
if (getInductionVar().getType() != getLowerBound().getType())
return emitOpError(
"expected induction variable to be same type as bounds and step");
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb093..d6dfd02 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
return std::make_pair(*newFuncOpOrFailure, newCallOp);
}
+
+FailureOr<func::FuncOp>
+func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT, SymbolTableCollection *symbolTables) {
+ FuncOp func;
+ if (symbolTables) {
+ func = symbolTables->lookupSymbolIn<FuncOp>(
+ symTable, StringAttr::get(symTable->getContext(), name));
+ } else {
+ func = llvm::dyn_cast_or_null<FuncOp>(
+ SymbolTable::lookupSymbolIn(symTable, name));
+ }
+
+ if (!func)
+ return func;
+
+ mlir::FunctionType foundFuncT = func.getFunctionType();
+ // Check that the signature of the found function matches the expected one.
+ if (funcT != foundFuncT) {
+ return func.emitError("matched function '")
+ << name << "' but with different type: " << foundFuncT
+ << " (expected " << funcT << ")";
+ }
+ return func;
+}
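
A usage sketch for the new lookupFnDecl helper, to make its two outcomes explicit: a type mismatch is a hard failure, while "not found" is reported as success holding a null FuncOp. The wrapper, the expf name, and its f32 signature are illustrative and assume the usual `using namespace mlir;` of MLIR sources:

// Illustrative caller; returns failure only on a signature mismatch.
static FailureOr<func::FuncOp> findExpfDecl(ModuleOp module,
                                            OpBuilder &builder) {
  Type f32 = builder.getF32Type();
  FunctionType fnTy = builder.getFunctionType({f32}, {f32});
  FailureOr<func::FuncOp> decl = func::lookupFnDecl(
      cast<SymbolOpInterface>(module.getOperation()), "expf", fnTy,
      /*symbolTables=*/nullptr);
  if (failed(decl))
    return failure(); // a symbol named "expf" exists but with a different type
  // *decl may still be a null FuncOp if no symbol named "expf" was found.
  return decl;
}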
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2..61a630a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
bool MMAMatrixType::isValidElementType(Type elementType) {
- return elementType.isF16() || elementType.isF32() ||
+ return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
elementType.isInteger(32);
}
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
if (!MMAMatrixType::isValidElementType(elementType))
return emitError()
- << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
+ << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
return success();
}
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index ec68acf..85b7b1ce 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRNVVMToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
+ MLIRVectorToLLVMPass
MLIRVectorToSCF
MLIRXeGPUTransforms
MLIRXeGPUToXeVM
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index 2c3e466..5462cdd 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -72,6 +72,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
ConvertGpuOpsToNVVMOpsOptions opt;
opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
opt.indexBitwidth = options.indexBitWidth;
+ opt.allowPatternRollback = options.allowPatternRollback;
pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt));
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 1a1485b..38313dc 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -63,13 +63,20 @@ void buildGPUPassPipeline(OpPassManager &pm,
if (options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ xegpu::XeGPUPropagateLayoutOptions layoutOptions;
+ layoutOptions.layoutKind = "inst";
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ xegpu::createXeGPUPropagateLayout(layoutOptions));
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}
if (options.xegpuOpLevel == "subgroup" ||
options.xegpuOpLevel == "workgroup") {
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+ xegpu::XeGPUPropagateLayoutOptions layoutOptions;
+ layoutOptions.layoutKind = "lane";
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ xegpu::createXeGPUPropagateLayout(layoutOptions));
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
@@ -104,8 +111,11 @@ void buildPostGPUCommonPassPipeline(
pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
}
pm.addPass(createLowerAffinePass());
+ pm.addPass(createConvertVectorToLLVMPass());
pm.addPass(createConvertToLLVMPass());
pm.addPass(createReconcileUnrealizedCastsPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
// gpu-module-to-binary
{
GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index cd13840..70d2e11 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -143,8 +143,8 @@ private:
};
/// Erases `executeOp` and returns a clone with additional `results`.
-async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
- ValueRange results) {
+static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
+ ValueRange results) {
// Add values to async.yield op.
Operation *yieldOp = executeOp.getBody()->getTerminator();
yieldOp->insertOperands(yieldOp->getNumOperands(), results);
diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index 3c44733..95d5cad 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -39,10 +39,10 @@ void GpuModuleToBinaryPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
auto targetFormat =
llvm::StringSwitch<std::optional<CompilationTarget>>(compilationTarget)
- .Cases("offloading", "llvm", CompilationTarget::Offload)
- .Cases("assembly", "isa", CompilationTarget::Assembly)
- .Cases("binary", "bin", CompilationTarget::Binary)
- .Cases("fatbinary", "fatbin", CompilationTarget::Fatbin)
+ .Cases({"offloading", "llvm"}, CompilationTarget::Offload)
+ .Cases({"assembly", "isa"}, CompilationTarget::Assembly)
+ .Cases({"binary", "bin"}, CompilationTarget::Binary)
+ .Cases({"fatbinary", "fatbin"}, CompilationTarget::Fatbin)
.Default(std::nullopt);
if (!targetFormat)
getOperation()->emitError() << "Invalid format specified.";
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 212ccc9..8d10aac 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -169,7 +169,7 @@ LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
LogicalResult getOperandSegmentSizes(Operation *op,
ArrayRef<Variadicity> variadicities,
SmallVectorImpl<int> &segmentSizes) {
- return getSegmentSizes(op, "operand", "operand_segment_sizes",
+ return getSegmentSizes(op, "operand", "operandSegmentSizes",
op->getNumOperands(), variadicities, segmentSizes);
}
@@ -180,7 +180,7 @@ LogicalResult getOperandSegmentSizes(Operation *op,
LogicalResult getResultSegmentSizes(Operation *op,
ArrayRef<Variadicity> variadicities,
SmallVectorImpl<int> &segmentSizes) {
- return getSegmentSizes(op, "result", "result_segment_sizes",
+ return getSegmentSizes(op, "result", "resultSegmentSizes",
op->getNumResults(), variadicities, segmentSizes);
}
diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
index 183d0e3..887e8e1 100644
--- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::index;
@@ -15,10 +16,23 @@ using namespace mlir::index;
//===----------------------------------------------------------------------===//
// IndexDialect
//===----------------------------------------------------------------------===//
+namespace {
+/// This class defines the interface for handling inlining for index
+/// dialect operations.
+struct IndexInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ /// All index dialect ops can be inlined.
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+ return true;
+ }
+};
+} // namespace
void IndexDialect::initialize() {
registerAttributes();
registerOperations();
+ addInterfaces<IndexInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>();
}
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index cc66fac..a73f0c1 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa3..160b6ae 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kPrintApFloat,
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index b8331e0..9f87e50 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) {
MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context,
ArrayRef<ModRefInfo> memInfoArgs) {
if (memInfoArgs.empty())
- return MemoryEffectsAttr::get(context, ModRefInfo::ModRef,
- ModRefInfo::ModRef, ModRefInfo::ModRef);
- if (memInfoArgs.size() == 3)
+ return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef,
+ /*argMem=*/ModRefInfo::ModRef,
+ /*inaccessibleMem=*/ModRefInfo::ModRef,
+ /*errnoMem=*/ModRefInfo::ModRef,
+ /*targetMem0=*/ModRefInfo::ModRef,
+ /*targetMem1=*/ModRefInfo::ModRef);
+ if (memInfoArgs.size() == 6)
return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1],
- memInfoArgs[2]);
+ memInfoArgs[2], memInfoArgs[3],
+ memInfoArgs[4], memInfoArgs[5]);
return {};
}
@@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() {
return false;
if (this->getOther() != ModRefInfo::ModRef)
return false;
+ if (this->getErrnoMem() != ModRefInfo::ModRef)
+ return false;
+ if (this->getTargetMem0() != ModRefInfo::ModRef)
+ return false;
+ if (this->getTargetMem1() != ModRefInfo::ModRef)
+ return false;
return true;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2731069..5b81948 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -640,8 +640,6 @@ SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
// Code for LLVM::GEPOp.
//===----------------------------------------------------------------------===//
-constexpr int32_t GEPOp::kDynamicIndex;
-
GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
getDynamicIndices());
@@ -4226,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() {
}
//===----------------------------------------------------------------------===//
+// UDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability UDivOp::getSpeculatability() {
+ // X / 0 => UB
+ Value divisor = getRhs();
+ if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
+ return Speculation::Speculatable;
+
+ return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
+// SDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability SDivOp::getSpeculatability() {
+ // This function conservatively assumes that all signed division by -1 are
+ // not speculatable.
+ // X / 0 => UB
+ // INT_MIN / -1 => UB
+ Value divisor = getRhs();
+ if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
+ matchPattern(divisor, m_IntRangeWithoutNegOneS()))
+ return Speculation::Speculatable;
+
+ return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ce93d18..5dc4fa2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
+static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
// See llvm/lib/IR/Type.cpp for reference.
@@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
properties |=
(LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
+ if (getExtTypeName() == kAMDGCNNamedBarrier)
+ properties |= LLVMTargetExtType::CanBeGlobal;
+
return (properties & prop) == prop;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a5ffb9e..5ce56e6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
@@ -48,6 +49,47 @@ using namespace NVVM;
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
//===----------------------------------------------------------------------===//
+// Helper/Utility methods
+//===----------------------------------------------------------------------===//
+
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInGenericSpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
+static bool isPtrInSharedClusterSpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
+}
+
+static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
+ llvm::Value *ptr,
+ NVVMMemorySpace targetAS) {
+ unsigned AS = static_cast<unsigned>(targetAS);
+ return builder.CreateAddrSpaceCast(
+ ptr, llvm::PointerType::get(builder.getContext(), AS));
+}
+
+// Helper to convert the NVVM dialect CTAGroupKind to the LLVM CTAGroupKind.
+static llvm::nvvm::CTAGroupKind
+getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
+ switch (ctaGroup) {
+ case NVVM::CTAGroupKind::CTA_1:
+ return llvm::nvvm::CTAGroupKind::CG_1;
+ case NVVM::CTAGroupKind::CTA_2:
+ return llvm::nvvm::CTAGroupKind::CG_2;
+ }
+ llvm_unreachable("unsupported cta_group value");
+}
+
+//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
return success();
}
+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+ bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+ if (isSharedCTA && getMulticastMask())
+ return emitError("Multicast is not supported with shared::cta mode.");
+
+ return success();
+}
+
+static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
+ NVVM::MemScopeKind scope,
+ Value retVal = nullptr) {
+ if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
+ return op->emitError("mbarrier scope must be either CTA or Cluster");
+
+ bool isSharedCluster = isPtrInSharedClusterSpace(addr);
+ bool hasRetValue = static_cast<bool>(retVal);
+ if (isSharedCluster && hasRetValue)
+ return op->emitError(
+ "mbarrier in shared_cluster space cannot return any value");
+
+ return success();
+}
+
+LogicalResult MBarrierArriveOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveDropOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveExpectTxOp::verify() {
+ // The inline-PTX version of this op does not support all features.
+ // With a predicate, this op lowers to inline PTX, so verify and
+ // error out if unsupported features are requested.
+ if (getPredicate()) {
+ if (getScope() != NVVM::MemScopeKind::CTA)
+ return emitError("mbarrier scope must be CTA when using predicate");
+
+ if (isPtrInSharedClusterSpace(getAddr()))
+ return emitError("mbarrier in shared_cluster space is not supported when "
+ "using predicate");
+
+ if (getRes())
+ return emitError("return-value is not supported when using predicate");
+
+ if (getRelaxed())
+ return emitError("mbarrier with relaxed semantics is not supported when "
+ "using predicate");
+ }
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierExpectTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierCompleteTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierTestWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierTryWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
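
The shared helper above enforces two constraints for all the arrive-like ops: the scope must be CTA or cluster, and an mbarrier addressed in shared_cluster space cannot return a state value. A hedged, standalone restatement of that predicate, with simple enums standing in for NVVM::MemScopeKind and the pointer address space (both assumed here for illustration):

  #include <cassert>

  // Assumed stand-ins for NVVM::MemScopeKind and the mbarrier address space.
  enum class Scope { CTA, Cluster, GPU, System };
  enum class Space { SharedCTA, SharedCluster, Generic };

  // Mirrors verifyMBarrierArriveLikeOp: true when the op would verify.
  bool isValidArriveLike(Scope scope, Space addrSpace, bool hasResult) {
    if (scope != Scope::CTA && scope != Scope::Cluster)
      return false;                       // only CTA/cluster scopes are allowed
    if (addrSpace == Space::SharedCluster && hasResult)
      return false;                       // no result for shared_cluster mbarriers
    return true;
  }

  int main() {
    assert(isValidArriveLike(Scope::CTA, Space::SharedCTA, /*hasResult=*/true));
    assert(!isValidArriveLike(Scope::GPU, Space::SharedCTA, false));
    assert(!isValidArriveLike(Scope::Cluster, Space::SharedCluster, true));
    assert(isValidArriveLike(Scope::Cluster, Space::SharedCluster, false));
    return 0;
  }
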
+
LogicalResult ConvertFloatToTF32Op::verify() {
using RndMode = NVVM::FPRoundingMode;
switch (getRnd()) {
@@ -365,6 +484,108 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+LogicalResult PermuteOp::verify() {
+ using Mode = NVVM::PermuteMode;
+ bool hasHi = static_cast<bool>(getHi());
+
+ switch (getMode()) {
+ case Mode::DEFAULT:
+ case Mode::F4E:
+ case Mode::B4E:
+ if (!hasHi)
+ return emitError("mode '")
+ << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
+ break;
+ case Mode::RC8:
+ case Mode::ECL:
+ case Mode::ECR:
+ case Mode::RC16:
+ if (hasHi)
+ return emitError("mode '") << stringifyPermuteMode(getMode())
+ << "' does not accept 'hi' operand.";
+ break;
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Stochastic Rounding Conversion Ops
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
+ FPRoundingMode rnd,
+ bool hasRandomBits,
+ Operation *op) {
+ static constexpr FPRoundingMode validRndModes[] = {
+ FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
+
+ if (!llvm::is_contained(validRndModes, rnd)) {
+ return op->emitOpError(
+ "Only RN, RZ, and RS rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << dstType << ".";
+ }
+
+ if (rnd == FPRoundingMode::RS) {
+ if (!hasRandomBits) {
+ return op->emitOpError("random_bits is required for RS rounding mode.");
+ }
+ } else {
+ if (hasRandomBits) {
+ return op->emitOpError(
+ "random_bits is not supported for RN and RZ rounding modes.");
+ }
+ }
+
+ return success();
+}
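
The helper above encodes a pairing rule rather than per-mode special cases: RS is the only stochastic mode and therefore the only one that takes (and requires) random_bits, while RN and RZ must reject it. A small sketch of the same rule, assuming a trimmed-down stand-in for NVVM::FPRoundingMode:

  #include <cassert>

  // Assumed, cut-down stand-in for NVVM::FPRoundingMode.
  enum class Rnd { RN, RZ, RS, Other };

  // Mirrors verifyConvertF32x2ToFP16x2Op: returns true when the combination
  // of rounding mode and random_bits presence would pass verification.
  bool isValidF32x2Conversion(Rnd rnd, bool hasRandomBits) {
    if (rnd != Rnd::RN && rnd != Rnd::RZ && rnd != Rnd::RS)
      return false;                       // only RN/RZ/RS are accepted
    if (rnd == Rnd::RS)
      return hasRandomBits;               // stochastic rounding needs entropy
    return !hasRandomBits;                // RN/RZ reject random_bits
  }

  int main() {
    assert(isValidF32x2Conversion(Rnd::RS, /*hasRandomBits=*/true));
    assert(!isValidF32x2Conversion(Rnd::RS, /*hasRandomBits=*/false));
    assert(!isValidF32x2Conversion(Rnd::RN, /*hasRandomBits=*/true));
    assert(!isValidF32x2Conversion(Rnd::Other, /*hasRandomBits=*/false));
    return 0;
  }
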
+
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
+ static_cast<bool>(getRandomBits()), *this);
+}
+
+LogicalResult ConvertF32x2ToBF16x2Op::verify() {
+ return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
+ static_cast<bool>(getRandomBits()), *this);
+}
+
+LogicalResult ConvertF32x4ToF8x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f32x4 to f8x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF6x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f32x4 to f6x4.";
+
+ return success();
+}
+
+LogicalResult ConvertF32x4ToF4x4Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from "
+ "f32x4 to f4x4.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -866,16 +1087,517 @@ LogicalResult MmaOp::verify() {
return success();
}
-LogicalResult ShflOp::verify() {
- if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
+MMATypes MmaSpOp::accumPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
+ getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
+ assert(val.has_value() && "accumulator PTX type should always be inferrable");
+ return val.value();
+}
+
+MMATypes MmaSpOp::resultPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val =
+ MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
+ assert(val.has_value() && "result PTX type should always be inferrable");
+ return val.value();
+}
+
+mlir::NVVM::IDArgPair
+MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MmaSpOp>(op);
+
+ // Get operands
+ llvm::SmallVector<llvm::Value *> args;
+ for (mlir::Value v : thisOp.getOperands())
+ args.push_back(mt.lookupValue(v));
+
+ // Get intrinsic ID using the existing getIntrinsicID method
+ auto intId = MmaSpOp::getIntrinsicID(
+ thisOp.getShape().getM(), thisOp.getShape().getN(),
+ thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
+ thisOp.getOrderedMetadata(), thisOp.getKind(),
+ *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
+ thisOp.accumPtxType(), thisOp.resultPtxType());
+
+ return {intId, args};
+}
+
+void MmaSpOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+ };
+
+ std::array<OperandFragment, 5> frags{
+ OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
+ OperandFragment("selector", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
+
+ // Handle variadic operands A, B, C
+ for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(this->getOperand(operandIdx));
+ if (operandIdx == varOperandSpec.first) {
+ regTypes.push_back(this->getOperand(operandIdx).getType());
+ }
+ }
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+
+ // Handle sparse metadata and selector (single operands)
+ frags[3].regs.push_back(getSparseMetadata());
+ frags[4].regs.push_back(getSparsitySelector());
+
+ auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+ p << " " << frag.operandName;
+ p << "[";
+ p.printOperands(frag.regs);
+ p << "]";
+ };
+
+ for (const auto &frag : frags)
+ printMmaSpOperand(frag);
+
+ p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+ p << " : ";
+ p << "(";
+ for (int i = 0; i < 3; ++i) {
+ p << regTypes[i];
+ if (i < 2)
+ p << ", ";
+ }
+ p << ") -> " << getResult().getType();
+}
+
+void MmaSpOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
+ std::optional<MMAIntOverflow> intOverflow,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+ MLIRContext *ctx = builder.getContext();
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+ result.addOperands(sparseMetadata);
+ result.addOperands(sparsitySelector);
+
+ if (multiplicandPtxTypes) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+ if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+ }
+
+ if (intOverflow.has_value())
+ result.addAttribute("intOverflowBehavior",
+ MMAIntOverflowAttr::get(ctx, *intOverflow));
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()), 1,
+ 1})); // sparseMetadata and sparsitySelector
+}
+
+ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
+ struct OperandFragment {
+ std::optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ SmallVector<Type> regTypes;
+ };
+
+ Builder &builder = parser.getBuilder();
+ std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector, result
+
+ NamedAttrList namedAttributes;
+
+ // A helper to parse the operand segments.
+ auto parseMmaSpOperand = [&](StringRef operandName,
+ OperandFragment &frag) -> LogicalResult {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser
+ .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
+ return failure();
return success();
- auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
- auto elementType = (type && type.getBody().size() == 2)
- ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
- : nullptr;
- if (!elementType || elementType.getWidth() != 1)
- return emitError("expected return type to be a two-element struct with "
- "i1 as the second element");
+ };
+
+ // Parse the operand segments.
+ if (parseMmaSpOperand("A", frags[0]).failed())
+ return failure();
+ if (parseMmaSpOperand("B", frags[1]).failed())
+ return failure();
+ if (parseMmaSpOperand("C", frags[2]).failed())
+ return failure();
+ if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
+ return failure();
+ if (parseMmaSpOperand("selector", frags[4]).failed())
+ return failure();
+
+ if (parser.parseOptionalAttrDict(namedAttributes).failed())
+ return failure();
+
+ // Parse the type specification and resolve operands.
+ SmallVector<Type, 3> operandTypes;
+ if (failed(parser.parseColon()))
+ return failure();
+ if (failed(parser.parseLParen()))
+ return failure();
+ if (failed(parser.parseTypeList(operandTypes)))
+ return failure();
+ if (failed(parser.parseRParen()))
+ return failure();
+ if (operandTypes.size() != 3)
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expected one type for each of the A, B, and C operand segments "
+ "but got " + Twine(operandTypes.size()) + " types");
+ for (const auto &iter : llvm::enumerate(operandTypes)) {
+ auto &frag = frags[iter.index()];
+ frag.regTypes.resize(frag.regs.size(), iter.value());
+ if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+ parser.getNameLoc(), result.operands)))
+ return failure();
+ frag.elemtype =
+ MmaOp::inferOperandMMAType(frag.regTypes[0],
+ /*isAccumulator*/ iter.index() >= 2);
+ }
+
+ Type resultType;
+ if (parser.parseArrow() || parser.parseType(resultType))
+ return failure();
+ frags[5].elemtype =
+ MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
+
+ // Resolve sparse metadata and selector (assume i32 type)
+ Type i32Type = builder.getIntegerType(32);
+ if (parser
+ .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+ if (parser
+ .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+
+ std::array<StringRef, 2> names{"multiplicandAPtxType",
+ "multiplicandBPtxType"};
+ for (unsigned idx = 0; idx < names.size(); idx++) {
+ const auto &frag = frags[idx];
+ std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+ if (!frag.elemtype.has_value() && !attr.has_value()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "attribute " + names[idx] +
+ " is not provided explicitly and cannot be inferred");
+ }
+ if (!attr.has_value())
+ result.addAttribute(
+ names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+ }
+
+ result.addTypes(resultType);
+ if (!namedAttributes.empty())
+ result.addAttributes(namedAttributes);
+ result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // sparseMetadata
+ 1 // sparsitySelector
+ }));
+ return success();
+}
+
+LogicalResult MmaSpOp::verify() {
+ MLIRContext *context = getContext();
+ auto f16Ty = Float16Type::get(context);
+ auto i32Ty = IntegerType::get(context, 32);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f32Ty = Float32Type::get(context);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+ auto s32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+ auto f32x8StructTy =
+ LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+ auto f16x2x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+ auto f32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+ auto s32x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+ std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
+ getShapeAttr().getK()};
+
+ // These variables define the set of allowed data types for matrices A, B, C,
+ // and result.
+ using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+ using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+ AllowedShapes allowedShapes;
+ AllowedTypes expectedA;
+ AllowedTypes expectedB;
+ AllowedTypes expectedC;
+ SmallVector<Type> expectedResult;
+
+ // When M = 16, we just need to calculate the number of 8xk tiles, where
+ // k is a factor that depends on the data type.
+ if (mmaShape[0] == 16) {
+ int64_t kFactor;
+ Type multiplicandFragType;
+ switch (*getMultiplicandAPtxType()) {
+ case MMATypes::tf32:
+ kFactor = 4;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
+ allowedShapes.push_back({16, 8, 8});
+ allowedShapes.push_back({16, 8, 16});
+ break;
+ case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::f16:
+ kFactor = 8;
+ multiplicandFragType = f16x2Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k16 and m16n8k32 for f16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::s4:
+ case MMATypes::u4:
+ kFactor = 32;
+ // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
+ allowedShapes.push_back({16, 8, 64});
+ allowedShapes.push_back({16, 8, 128});
+ break;
+ case MMATypes::s8:
+ case MMATypes::u8:
+ kFactor = 16;
+ // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
+ allowedShapes.push_back({16, 8, 32});
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ case MMATypes::e4m3:
+ case MMATypes::e5m2:
+ case MMATypes::e3m2:
+ case MMATypes::e2m3:
+ case MMATypes::e2m1:
+ kFactor = 32;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k64 for FP8 types
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ default:
+ return emitError("invalid shape or multiplicand type: " +
+ stringifyEnum(getMultiplicandAPtxType().value()));
+ }
+
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedResult.push_back(s32x4StructTy);
+ expectedC.emplace_back(4, i32Ty);
+ multiplicandFragType = i32Ty;
+ } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
+ *getMultiplicandAPtxType() <= MMATypes::e2m1) {
+ // FP8 types
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ } else {
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ }
+
+ // For sparse MMA, A operand is compressed (2:4 sparsity means half the
+ // elements)
+ int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
+ int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+ expectedA.emplace_back(unitA, multiplicandFragType);
+ expectedB.emplace_back(unitB, multiplicandFragType);
+
+ if (resultPtxType() != accumPtxType())
+ return emitOpError("ctype does not match dtype");
+ }
+
+ // In the M=8 case, there is only 1 possible case per data type.
+ if (mmaShape[0] == 8) {
+ if (*getMultiplicandAPtxType() == MMATypes::f16) {
+ expectedA.emplace_back(2, f16x2Ty);
+ expectedB.emplace_back(2, f16x2Ty);
+ expectedResult.push_back(f16x2x4StructTy);
+ expectedResult.push_back(f32x8StructTy);
+ expectedC.emplace_back(4, f16x2Ty);
+ expectedC.emplace_back(8, f32Ty);
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (*getMultiplicandAPtxType() == MMATypes::f64) {
+ Type f64Ty = Float64Type::get(context);
+ expectedA.emplace_back(1, f64Ty);
+ expectedB.emplace_back(1, f64Ty);
+ expectedC.emplace_back(2, f64Ty);
+ expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(2, f64Ty)));
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedA.push_back({i32Ty});
+ expectedB.push_back({i32Ty});
+ expectedC.push_back({i32Ty, i32Ty});
+ expectedResult.push_back(s32x2StructTy);
+ if (isInt4PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 32});
+ if (isInt8PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 16});
+ }
+ }
+
+ std::string errorMessage;
+ llvm::raw_string_ostream errorStream(errorMessage);
+
+ // Check that we matched an existing shape/dtype combination.
+ if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+ !llvm::is_contained(allowedShapes, mmaShape)) {
+ errorStream << "unimplemented variant for MMA shape <";
+ llvm::interleaveComma(mmaShape, errorStream);
+ errorStream << ">";
+ return emitOpError(errorMessage);
+ }
+
+ // Verify the operand types for segments of A, B, and C operands.
+ std::array<StringRef, 3> operandNames{"A", "B", "C"};
+ for (const auto &iter : llvm::enumerate(
+ SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+ auto spec = this->getODSOperandIndexAndLength(iter.index());
+ SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+ operand_type_begin() + spec.first +
+ spec.second);
+ bool match = llvm::is_contained(iter.value(), operandTySeg);
+
+ if (!match) {
+ errorStream << "Could not match types for the "
+ << operandNames[iter.index()]
+ << " operands; expected one of ";
+ for (const auto &x : iter.value()) {
+ errorStream << x.size() << "x" << x[0] << " ";
+ }
+ errorStream << "but got ";
+ llvm::interleaveComma(operandTySeg, errorStream);
+ return emitOpError(errorMessage);
+ }
+ }
+
+ // Check the result type
+ if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+ return expectedResultType == getResult().getType();
+ })) {
+ errorStream
+ << "Could not match allowed types for the result; expected one of ";
+ llvm::interleaveComma(expectedResult, errorStream);
+ errorStream << " but got " << getResult().getType();
+ return emitOpError(errorMessage);
+ }
+
+ // Ensure int4/int8 MMA variants specify the accum overflow behavior
+ // attribute.
+ if (isInt4PtxType(*getMultiplicandAPtxType()) ||
+ isInt8PtxType(*getMultiplicandAPtxType())) {
+ if (!getIntOverflowBehavior())
+ return emitOpError("op requires " +
+ getIntOverflowBehaviorAttrName().strref() +
+ " attribute");
+ }
+
+ // Validate sparse metadata type (should be i32)
+ if (!getSparseMetadata().getType().isInteger(32)) {
+ return emitOpError() << "sparse metadata must be i32 type";
+ }
+
+ // Validate sparsity selector type (should be i32)
+ if (!getSparsitySelector().getType().isInteger(32)) {
+ return emitOpError() << "sparsity selector must be i32 type";
+ }
+
+ return success();
+}
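
For the M = 16 sparse variants, the expected fragment counts fall out of the tile arithmetic in the verifier: the A operand is halved relative to dense MMA because of 2:4 structured sparsity. A worked, standalone version of that arithmetic (reusing the kFactor values from the switch above; the helper names are made up for illustration):

  #include <cassert>
  #include <cstdint>

  // Number of A/B fragments per thread for an m16n8kK sparse MMA, using the
  // same formulas as MmaSpOp::verify(): A is 2:4 compressed, so it is halved.
  struct FragCounts { int64_t a, b; };

  FragCounts sparseFragCounts(int64_t m, int64_t n, int64_t k, int64_t kFactor) {
    int64_t unitA = (m / 8) * (k / kFactor) / 2;
    int64_t unitB = (n / 8) * (k / kFactor);
    return {unitA, unitB};
  }

  int main() {
    // m16n8k32 with f16 (kFactor = 8): 4 A fragments and 4 B fragments.
    FragCounts f16 = sparseFragCounts(16, 8, 32, /*kFactor=*/8);
    assert(f16.a == 4 && f16.b == 4);

    // m16n8k8 with tf32 (kFactor = 4): 2 A fragments and 2 B fragments.
    FragCounts tf32 = sparseFragCounts(16, 8, 8, /*kFactor=*/4);
    assert(tf32.a == 2 && tf32.b == 2);
    return 0;
  }
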
+
+LogicalResult ShflOp::verify() {
+ auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+
+ auto verifyTypeError = [&](Twine desc, Type expectedType,
+ Type actualType) -> LogicalResult {
+ return emitOpError("expected " + desc + " to be of type ")
+ << expectedType << " but got " << actualType << " instead";
+ };
+
+ if (returnStructType) {
+ if (!getReturnValueAndIsValid())
+ return emitOpError("\"return_value_and_is_valid\" attribute must be "
+ "specified when the return type is a struct type");
+
+ if (returnStructType.getBody().size() != 2)
+ return emitOpError("expected return type to be a two-element struct");
+
+ llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
+ auto resultType = returnStruct[0];
+ if (resultType != getVal().getType())
+ return verifyTypeError("first element in the returned struct",
+ getVal().getType(), resultType);
+
+ auto predicateType = returnStruct[1];
+ if (!predicateType.isInteger(1))
+ return verifyTypeError("second element in the returned struct",
+ mlir::IntegerType::get(getContext(), 1),
+ predicateType);
+ } else {
+ if (getReturnValueAndIsValid())
+ return emitOpError("expected return type to be a two-element struct");
+
+ if (getType() != getVal().getType())
+ return verifyTypeError("return type", getVal().getType(), getType());
+ }
return success();
}
@@ -1376,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
return true; // Has manual mapping
}
+LogicalResult NVVM::FenceSyncRestrictOp::verify() {
+ if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+ getOrder() != NVVM::MemOrderKind::RELEASE)
+ return emitOpError("only acquire and release semantics are supported");
+ return success();
+}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1398,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() {
if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
return emitOpError("uni-directional proxies only support tensormap "
"for to_proxy attribute");
-
return success();
}
@@ -1410,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() {
if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
return emitOpError("uni-directional proxies only support tensormap "
"for to_proxy attribute");
+ return success();
+}
+
+LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
+ if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+ getOrder() != NVVM::MemOrderKind::RELEASE)
+ return emitOpError("only acquire and release semantics are supported");
+ if (getFromProxy() != NVVM::ProxyKind::GENERIC)
+ return emitOpError("only generic is supported for from_proxy attribute");
+
+ if (getToProxy() != NVVM::ProxyKind::async)
+ return emitOpError("only async is supported for to_proxy attribute");
return success();
}
@@ -1426,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() {
if (getNumberOfThreads() && !getBarrierId())
return emitOpError(
"barrier id is missing, it should be set between 0 and 15");
+
+ if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
+ return emitOpError("reduction is only available when the barrier id is 0");
+
+ if ((getReductionOp() && !getReductionPredicate()) ||
+ (!getReductionOp() && getReductionPredicate()))
+ return emitOpError("reduction predicate and reduction operation must be "
+ "specified together");
+
return success();
}
@@ -1577,6 +2326,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
return success();
}
+LogicalResult NVVM::ReduxOp::verify() {
+ mlir::Type reduxType = getType();
+
+ if (!reduxType.isF32()) {
+ if (getAbs())
+ return emitOpError("abs attribute is supported only for f32 type");
+ if (getNan())
+ return emitOpError("nan attribute is supported only for f32 type");
+ }
+
+ NVVM::ReduxKind kind = getKind();
+ switch (kind) {
+ case NVVM::ReduxKind::ADD:
+ case NVVM::ReduxKind::AND:
+ case NVVM::ReduxKind::OR:
+ case NVVM::ReduxKind::XOR:
+ case NVVM::ReduxKind::MAX:
+ case NVVM::ReduxKind::MIN:
+ case NVVM::ReduxKind::UMAX:
+ case NVVM::ReduxKind::UMIN:
+ if (!reduxType.isInteger(32))
+ return emitOpError("'")
+ << stringifyEnum(kind) << "' redux kind unsupported with "
+ << reduxType << " type. Only supported type is 'i32'.";
+ break;
+ case NVVM::ReduxKind::FMIN:
+ case NVVM::ReduxKind::FMAX:
+ if (!reduxType.isF32())
+ return emitOpError("'")
+ << stringifyEnum(kind) << "' redux kind unsupported with "
+ << reduxType << " type. Only supported type is 'f32'.";
+ break;
+ }
+
+ return success();
+}
+
/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
@@ -1626,26 +2412,76 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
//===----------------------------------------------------------------------===//
std::string NVVM::MBarrierInitOp::getPtx() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
- return (addressSpace == NVVMMemorySpace::Shared)
- ? std::string("mbarrier.init.shared.b64 [%0], %1;")
- : std::string("mbarrier.init.b64 [%0], %1;");
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
+}
+
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared
+ ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
+ : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
+}
+
+std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ llvm::StringRef space = isShared ? ".shared" : "";
+
+ return llvm::formatv("{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}",
+ space);
}
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
+mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::BarrierOp>(op);
+ llvm::Value *barrierId = thisOp.getBarrierId()
+ ? mt.lookupValue(thisOp.getBarrierId())
+ : builder.getInt32(0);
+ llvm::Intrinsic::ID id;
+ llvm::SmallVector<llvm::Value *> args;
+ if (thisOp.getNumberOfThreads()) {
+ id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
+ args.push_back(barrierId);
+ args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
+ } else if (thisOp.getReductionOp()) {
+ switch (*thisOp.getReductionOp()) {
+ case NVVM::BarrierReduction::AND:
+ id = llvm::Intrinsic::nvvm_barrier0_and;
+ break;
+ case NVVM::BarrierReduction::OR:
+ id = llvm::Intrinsic::nvvm_barrier0_or;
+ break;
+ case NVVM::BarrierReduction::POPC:
+ id = llvm::Intrinsic::nvvm_barrier0_popc;
+ break;
+ }
+ args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
+ } else {
+ id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
+ args.push_back(barrierId);
+ }
+
+ return {id, std::move(args)};
+}
+
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
- .getAddressSpace();
- llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
- ? llvm::Intrinsic::nvvm_mbarrier_init_shared
- : llvm::Intrinsic::nvvm_mbarrier_init;
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
+ : llvm::Intrinsic::nvvm_mbarrier_init;
// Fill the Intrinsic Args
llvm::SmallVector<llvm::Value *> args;
@@ -1658,16 +2494,353 @@ mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
- .getAddressSpace();
- llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id = isShared
? llvm::Intrinsic::nvvm_mbarrier_inval_shared
: llvm::Intrinsic::nvvm_mbarrier_inval;
return {id, {mt.lookupValue(thisOp.getAddr())}};
}
+mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getTxcount()));
+
+ return {IDs[index], std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getTxcount()));
+
+ return {IDs[index], std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // When count is not explicitly specified, the default is 1.
+ llvm::LLVMContext &ctx = mt.getLLVMContext();
+ bool hasCount = static_cast<bool>(thisOp.getCount());
+ llvm::Value *count =
+ hasCount ? mt.lookupValue(thisOp.getCount())
+ : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
+
+ return {id, {mbar, count}};
+}
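
Most of the mbarrier lowerings above share the same two-bit packing to select one of four intrinsic variants: bit 0 encodes the address space (CTA vs. cluster shared memory) and bit 1 the memory scope. A minimal sketch of that indexing, with placeholder strings standing in for the llvm::Intrinsic IDs:

  #include <cassert>
  #include <cstddef>
  #include <string>

  // Placeholder table; in the dialect code each entry is an llvm::Intrinsic::ID.
  static const std::string kArriveVariants[4] = {
      "arrive_scope_cta_space_cta", "arrive_scope_cta_space_cluster",
      "arrive_scope_cluster_space_cta", "arrive_scope_cluster_space_cluster"};

  // bit-0: shared_cluster address space, bit-1: cluster scope.
  std::string selectArriveVariant(bool isClusterScope, bool isClusterSpace) {
    size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
    return kArriveVariants[index];
  }

  int main() {
    assert(selectArriveVariant(false, false) == "arrive_scope_cta_space_cta");
    assert(selectArriveVariant(true, false) == "arrive_scope_cluster_space_cta");
    assert(selectArriveVariant(true, true) ==
           "arrive_scope_cluster_space_cluster");
    return 0;
  }

The relaxed-ordering tables in the lowerings above are selected the same way; only the table changes, not the index computation.
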
+
+mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // When count is not explicitly specified, the default is 1.
+ llvm::LLVMContext &ctx = mt.getLLVMContext();
+ bool hasCount = static_cast<bool>(thisOp.getCount());
+ llvm::Value *count =
+ hasCount ? mt.lookupValue(thisOp.getCount())
+ : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
+
+ return {id, {mbar, count}};
+}
+
+bool MBarrierArriveExpectTxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ // Add all the operands but not the attrs to the asmValues list.
+ // The attrs here are only used to select the right intrinsic variant
+ // during lowering, so we ignore them when generating inline PTX.
+ for (auto val : getOperands())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+ return false;
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, txcount}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, txcount}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id =
+ isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+ llvm::Intrinsic::ID id =
+ isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: isPhaseParity
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, input}};
+}
+
+mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ bool hasTicks = static_cast<bool>(thisOp.getTicks());
+ // bit-0: isPhaseParity
+ // bit-1: Scope
+ // bit-2: hasTicks
+ size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
+ (isPhaseParity ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the mbarrier pointer
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mbar);
+ args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
+ if (hasTicks)
+ args.push_back(mt.lookupValue(thisOp.getTicks()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
+ bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
+
+ llvm::Intrinsic::ID id;
+ if (thisOp.getNoinc()) {
+ id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
+ : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
+ } else {
+ id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
+ : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
+ }
+
+ return {id, {mt.lookupValue(thisOp.getAddr())}};
+}
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
@@ -1737,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
args.push_back(mt.lookupValue(thisOp.getSize()));
- // Multicast mask, if available.
+ // Multicast mask for shared::cluster only, if available.
mlir::Value multicastMask = thisOp.getMulticastMask();
const bool hasMulticastMask = static_cast<bool>(multicastMask);
- llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
- args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+ const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+ if (!isSharedCTA) {
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+ : i16Unused);
+ }
// Cache hint, if available.
mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1750,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
// Flag arguments for multicast and cachehint.
- args.push_back(builder.getInt1(hasMulticastMask));
+ if (!isSharedCTA)
+ args.push_back(builder.getInt1(hasMulticastMask));
args.push_back(builder.getInt1(hasCacheHint));
llvm::Intrinsic::ID id =
- llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+ isSharedCTA
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+ : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
return {id, std::move(args)};
}
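
The reworked lowering above emits a different argument list per destination space: the shared::cta intrinsic has no multicast slot, so both the mask operand and its flag are dropped. A sketch of the resulting argument shapes, using placeholder strings instead of llvm::Value pointers (the helper and its names are assumptions, not part of the dialect):

  #include <cassert>
  #include <string>
  #include <vector>

  // Builds the illustrative argument list for the bulk-copy lowering:
  // dst, src, size, [multicastMask], l2CacheHint, [hasMulticast], hasCacheHint.
  std::vector<std::string> buildBulkCopyArgs(bool isSharedCTA, bool hasMask,
                                             bool hasHint) {
    std::vector<std::string> args = {"dst", "src", "size"};
    if (!isSharedCTA)
      args.push_back(hasMask ? "mask" : "i16 0");
    args.push_back(hasHint ? "hint" : "i64 0");
    if (!isSharedCTA)
      args.push_back(hasMask ? "i1 true" : "i1 false");
    args.push_back(hasHint ? "i1 true" : "i1 false");
    return args;
  }

  int main() {
    // shared::cluster destination keeps the multicast slots (7 args).
    assert(buildBulkCopyArgs(/*isSharedCTA=*/false, true, false).size() == 7);
    // shared::cta destination drops them (5 args).
    assert(buildBulkCopyArgs(/*isSharedCTA=*/true, false, true).size() == 5);
    return 0;
  }
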
@@ -2469,6 +3649,155 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
+}
+
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+ }
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
+ mlir::Type dstTy = getDstTy();
+ bool hasRelu = getRelu();
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
+ : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
@@ -2508,6 +3837,9 @@ LogicalResult Tcgen05LdOp::verify() {
if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
result = emitError("shape 16x32bx2 requires offset argument");
+ if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
+ result = emitError("offset argument is only supported for shape 16x32bx2");
+
auto resTy = getRes().getType();
unsigned resLen = isa<VectorType>(resTy)
? llvm::cast<VectorType>(resTy).getNumElements()
@@ -2751,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}
+mlir::NVVM::IDArgPair
+PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::PermuteOp>(op);
+ NVVM::PermuteMode mode = thisOp.getMode();
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
+ llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
+ llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
+ llvm::Intrinsic::nvvm_prmt_rc16};
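+ // Note: IDs is indexed directly by the PermuteMode enum value, so its
+ // element order is assumed to match the enum declaration order
+ // (Default, f4e, b4e, rc8, ecl, ecr, rc16).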
+
+ unsigned modeIndex = static_cast<unsigned>(mode);
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getLo()));
+
+ // Only the first 3 modes (Default, f4e, b4e) need the hi operand.
+ if (modeIndex < 3)
+ args.push_back(mt.lookupValue(thisOp.getHi()));
+
+ args.push_back(mt.lookupValue(thisOp.getSelector()));
+
+ return {IDs[modeIndex], args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ const bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
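+ // e.g. tcgen05MMAIDs[0][0][1][1][1] is nvvm_tcgen05_mma_tensor_ashift
+ // (no disable-output-lane, no scale_d, tensor A, cta_group::2, ashift).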
+ static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
+ { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ const unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
+ NVVM::CTAGroupKind ctaGroup, bool hasAShift,
+ NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
+
+ if (disableOutputLane) {
+ mlir::VectorType disableOutputLaneType =
+ cast<mlir::VectorType>(disableOutputLane.getType());
+ if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
+ disableOutputLaneType.getNumElements() != 4) ||
+ (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
+ disableOutputLaneType.getNumElements() != 8))
+ return emitError(loc) << "Disable Output Lane of length "
+ << disableOutputLaneType.getNumElements()
+ << " is incompatible with CtaGroupAttr";
+ }
+
+ if (hasAShift && !isATensor)
+ return emitError(
+ loc, "A-shift can be applied only when matrix A is in tensor memory");
+
+ if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
+ collectorOp == Tcgen05MMACollectorOp::USE))
+ return emitError(
+ loc, "Cannot use collector buffer operation fill or use with ashift");
+
+ return success();
+}
+
+LogicalResult Tcgen05MMAOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
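+ // This table mirrors tcgen05MMAIDs above, substituting the sparse (_sp_)
+ // intrinsic variants; the same index ordering applies.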
+ static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
+ { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
+ : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
+
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
+ NVVM::Tcgen05MMABlockScaleKind kind,
+ NVVM::Tcgen05MMABlockScale blockScale,
+ Location loc) {
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
+ kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
+ return emitError(loc, "mxf4nvf4 requires block scale attribute");
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
+ kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
+ return emitError(loc,
+ llvm::formatv("{} kind does not support block16 attribute",
+ stringifyEnum(kind)));
+
+ return success();
+}
+
+LogicalResult Tcgen05MMABlockScaleOp::verify() {
+ return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
+ getBlockScale(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
+
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
+ return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
+ getBlockScale(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ mlir::Value ZeroColMask = thisOp.getZeroColMask();
+ llvm::Intrinsic::ID ID = notIntrinsic;
+ if (ZeroColMask) {
+ args.push_back(mt.lookupValue(ZeroColMask));
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
+ } else
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+ mlir::Value ZeroColMask = thisOp.getZeroColMask();
+ llvm::Intrinsic::ID ID = notIntrinsic;
+ if (ZeroColMask) {
+ args.push_back(mt.lookupValue(ZeroColMask));
+ ID = isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
+ } else
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -2954,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
"Minimum NVVM target SM version is sm_20");
}
- gpuModuleOp->walk([&](Operation *op) {
- if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
- const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
- if (!requirement.isCompatibleWith(targetSMVersion)) {
- op->emitOpError() << "is not supported on " << getChip();
- return WalkResult::interrupt();
- }
- }
- return WalkResult::advance();
- });
+ if (gpuModuleOp
+ ->walk([&](Operation *op) {
+ if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
+ const NVVMCheckSMVersion requirement =
+ reqOp.getRequiredMinSMVersion();
+ if (!requirement.isCompatibleWith(targetSMVersion)) {
+ op->emitOpError() << "is not supported on " << getChip();
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ })
+ .wasInterrupted())
+ return failure();
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
index 67573c4..12dd225 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
@@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr,
return FusedLoc::get(context, {loc}, lexicalBlockFileAttr);
}
+/// Adds DILexicalBlockFileAttr for operations with CallSiteLoc and operations
+/// from different files than their containing function.
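+/// For example (illustrative file names): an op whose location points into
+/// "inlined.h" while its parent function is defined in "main.cpp" gets its
+/// location wrapped in a DILexicalBlockFileAttr referencing "inlined.h", so
+/// the emitted debug info resolves to the correct file.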
static void setLexicalBlockFileAttr(Operation *op) {
- if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) {
+ Location opLoc = op->getLoc();
+
+ if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) {
auto callerLoc = callSiteLoc.getCaller();
auto calleeLoc = callSiteLoc.getCallee();
LLVM::DIScopeAttr scopeAttr;
@@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) {
op->setLoc(
CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc));
}
+
+ return;
+ }
+
+ auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ if (!funcOp)
+ return;
+
+ FileLineColLoc opFileLoc = extractFileLoc(opLoc);
+ if (!opFileLoc)
+ return;
+
+ FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc());
+ if (!funcFileLoc)
+ return;
+
+ StringRef opFile = opFileLoc.getFilename().getValue();
+ StringRef funcFile = funcFileLoc.getFilename().getValue();
+
+ // Handle cross-file operations: add DILexicalBlockFileAttr when the
+ // operation's source file differs from that of its containing function.
+ if (opFile != funcFile) {
+ auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc());
+ if (!funcOpLoc)
+ return;
+ auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata());
+ if (!scopeAttr)
+ return;
+
+ auto *context = op->getContext();
+ LLVM::DIFileAttr opFileAttr =
+ LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile),
+ llvm::sys::path::parent_path(opFile));
+
+ LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr =
+ LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0);
+
+ Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr);
+ op->setLoc(newLoc);
}
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9..b4b1347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
+namespace {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
- NotScalarInput
+ NotScalarInput,
+ TypeMismatch
};
+} // namespace
static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
+ // Check that the scalar input type matches the output element type.
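+ // (e.g. a linalg.fill writing an i32 scalar into a tensor<4xf32> init is
+ // reported as a type mismatch).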
+ OpOperand *output = linalgOp.getDpsInitOperand(0);
+ Type scalarType = value->get().getType();
+ Type outputElementType = getElementTypeOrSelf(output->get().getType());
+ if (scalarType != outputElementType)
+ return MatchFillResult::TypeMismatch;
+
return MatchFillResult::Success;
}
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
- auto res = isFillInterfaceImpl(op);
+ MatchFillResult res = isFillInterfaceImpl(op);
if (res == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
if (res == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
if (res == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
+ if (res == MatchFillResult::TypeMismatch) {
+ auto linalgOp = cast<linalg::LinalgOp>(op);
+ Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+ Type outputElementType =
+ getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+ return op->emitOpError("expected fill value type (")
+ << scalarType << ") to match output element type ("
+ << outputElementType << ")";
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45ed..33ec79b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
-LogicalResult GenericOp::verify() { return success(); }
-
namespace {
/// Remove linalg operations that are just copying the values from inputs to
@@ -2091,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
return failure();
// Single dimension transpose.
- if (getPermutation().size() == 0) {
+ if (getPermutation().empty()) {
result.push_back(getInput());
return success();
}
@@ -4885,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
elidedAttrs);
}
-LogicalResult ElementwiseOp::verify() {
- // All necessary checks are done either by
- // - EnumAttr (e.g. unknown operation kind)
- // - verifyStructuredOpInterface (incorrect map, sizes).
- return success();
-}
-
/// Implements the block region builder for the ElementwiseOp. This is called by
/// 'fillStructuredOpRegion'.
void ElementwiseOp::regionBuilder(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3a43382..b8c1bad 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -176,7 +176,8 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
reified.push_back(cast<IntegerAttr>(attr).getInt());
continue;
- } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
+ }
+ if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
if (params.size() != 1)
return transformOp.emitSilenceableError() << "expected a single param";
@@ -997,8 +998,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Iterate over the outputs of the producer and over the loop bbArgs and
// check if any bbArg points to the same value as the producer output. In
// such case, make the producer output point to the bbArg directly.
- for (OpOperand &initOperandPtr :
- cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
+ auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
+ if (!dpsInterface)
+ return;
+
+ for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
Value producerOperand =
clone->getOperand(initOperandPtr.getOperandNumber());
for (BlockArgument containerIterArg :
@@ -1060,7 +1064,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
resultNumber, offsets, sizes);
// Cleanup clone.
- if (dyn_cast<LoopLikeOpInterface>(containingOp))
+ if (isa<LoopLikeOpInterface>(containingOp))
rewriter.eraseOp(tileableProducer);
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690da..9e6c1e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides));
+ reassociation->size(), sliceOp.getSourceType(), sizes));
Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cb..421ab5e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,62 @@ private:
ControlFusionFn controlFoldingReshapes;
};
+/// Carries per-dimension padding information for a pad operation.
+struct PadDimInfo {
+ // The resulting shape after padding each dimension.
+ SmallVector<int64_t> paddedShape;
+
+ // Low and high padding amounts for each dimension.
+ SmallVector<OpFoldResult> lowPad;
+ SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// shape for each dimension, or failure if the expansion is not possible.
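+/// For example (illustrative shapes): a pad of tensor<8x32xf32> with
+/// low = [1, 0] and high = [1, 0], expanded via reassociation [[0], [1, 2]]
+/// to shape 8x4x8, yields lowPad = [1, 0, 0], highPad = [1, 0, 0] and
+/// paddedShape = [10, 4, 8].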
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ // If the padding value depends on the index values of the pad operation,
+ // then it may not be valid to expand the dimensions, since it will change
+ // the index values on which the padding value depends. This is not currently
+ // supported by the pad expansion patterns, but it could be implemented
+ // similarly to the expansion of linalg.generic ops with linalg.index ops in
+ // the body, as is done in `updateExpandedGenericOpRegion`.
+ if (!padOp.getConstantPaddingValue())
+ return failure();
+
+ // Expanded dimensions cannot have padding because the resulting padding may
+ // not be representable by a tensor.pad op. There are some special cases where
+ // it is possible (like expanding unit dims), but supporting these cases is
+ // NYI, so disallow it for now.
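+ // e.g. a dim padded from 30 to 32 and then expanded into 4x8 would need
+ // padding on just the last two elements of the final row, which tensor.pad
+ // cannot express.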
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && (l != 0 || h != 0))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.paddedShape.assign(expandedShape);
+ padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+ padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+ padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+ }
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1053,46 +1109,96 @@ public:
padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
return failure();
- if (!reshapeOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo &expandedPadding = maybeExpandedPadding.value();
- for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && (l != 0 || h != 0))
- return failure();
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+ expandedPadding.lowPad, expandedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
}
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType expandedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+ RankedTensorType expandedType = expandOp.getResultType();
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo &expandedPadding = maybeExpandedPadding.value();
+
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+ SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+ rewriter.setInsertionPointAfterValue(padOp.getSource());
+ SmallVector<OpFoldResult> padSrcSizes =
+ tensor::getMixedSizes(rewriter, loc, padOp.getSource());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ // We know that any reassociation with multiple dims is not padded because
+ // of the requirements of computeExpandedPadding.
if (reInd.size() == 1) {
- expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
- }
- for (size_t i = 0; i < reInd.size(); ++i) {
- newLow.push_back(padOp.getMixedLowPad()[idx]);
- newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+ newExpandedSizes[reInd[0]] = padSrcSizes[idx];
}
}
-
- Location loc = padOp->getLoc();
- RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+ RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+ auto newExpandOp = tensor::ExpandShapeOp::create(
+ rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+ newExpandedSizes);
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+ rewriter.setInsertionPoint(expandOp);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+ expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
return success();
}
@@ -1921,6 +2027,62 @@ private:
ControlFusionFn controlFoldingReshapes;
};
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
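+/// For example (illustrative shapes): a pad of tensor<8x4x8xf32> with
+/// low = [1, 0, 0] and high = [1, 0, 0], collapsed via reassociation
+/// [[0], [1, 2]], yields lowPad = [1, 0], highPad = [1, 0] and
+/// paddedShape = [10, 32].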
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ // If the padding value depends on the index values of the pad operation,
+ // then it may not be valid to collapse the dimensions, since it will change
+ // the index values on which the padding value depends. This is not currently
+ // supported by the pad collapsing patterns, but it could be implemented
+ // similarly to the collapsing of linalg.generic ops with linalg.index ops in
+ // the body, as is done in `generateCollapsedIndexingRegion`.
+ if (!padOp.getConstantPaddingValue())
+ return failure();
+
+ // Collapsed dimensions cannot have padding because this can produce strided
+ // padding that isn't representable by a tensor.pad op. There are some special
+ // cases where it is possible (like collapsing unit dims), but supporting
+ // these cases is NYI, so disallow it for now.
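+ // e.g. collapsing 4x8 where the inner dim is padded from 6 to 8 would place
+ // padding at positions 6-7, 14-15, ... of the collapsed 32-element dim,
+ // i.e. strided padding that tensor.pad cannot express.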
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (int64_t dim : reInd) {
+ if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+ return failure();
+ }
+ }
+
+ // Initialize padding values for collapsed tensors with zeros
+ ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+ // Update padding for dimensions that are not being collapsed, and compute
+ // the collapsed padded shape.
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+ padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
+ }
+ SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+ for (int64_t dim : reInd) {
+ collapsedSize =
+ collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+ }
+ padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1936,57 +2098,40 @@ public:
padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
return failure();
- if (!reshapeOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
- for (auto reInd : reassociations) {
- if (reInd.size() == 1)
- continue;
- if (llvm::any_of(reInd, [&](int64_t ind) {
- return low[ind] != 0 || high[ind] != 0;
- })) {
- return failure();
- }
- }
-
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType collapsedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
- SmallVector<OpFoldResult> expandedPaddedSizes(
- getMixedValues(reshapeOp.getStaticOutputShape(),
- reshapeOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> expandedPaddedSizes =
+ reshapeOp.getMixedOutputShape();
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
- for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
- OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
- OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+ for (auto [reInd, l, h] :
+ llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+ collapsedPadding.highPad)) {
if (reInd.size() == 1) {
- collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
- OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
- expandedPaddedSizes[reInd[0]] = paddedSize;
}
- newLow.push_back(l);
- newHigh.push_back(h);
}
RankedTensorType collapsedPaddedType =
- paddedType.clone(collapsedPaddedShape);
+ padOp.getType().clone(collapsedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2145,52 @@ private:
ControlFusionFn controlFoldingReshapes;
};
+class FoldReshapeWithProducerPadOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(padOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+ RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
+
+ Location loc = reshapeOp->getLoc();
+ auto newCollapseOp = tensor::CollapseShapeOp::create(
+ rewriter, loc, padOp.getSource(), reassociations);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 9974ccd..cbd6357 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -200,10 +200,10 @@ static void populateOpPayload(
SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
- SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
- genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
- SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
- newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+ SmallVector<OpOperand *> origOutputOperands =
+ llvm::to_vector(llvm::make_pointer_range(genericOp.getDpsInitsMutable()));
+ SmallVector<OpOperand *> newOutputOperands =
+ llvm::to_vector(llvm::make_pointer_range(newOp.getDpsInitsMutable()));
updateReplacements(origOutputOperands, newOutputOperands,
origOutsToNewOutsPos);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 9436f1c..161d978 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -913,8 +913,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
llvm_unreachable("loop independence prerequisite not met");
// offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
- std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
- offsets.begin());
+ llvm::copy(loopIterationCounts, offsets.begin());
hoistedPackedTensor =
scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
->getResult(0);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d6..c2485a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,69 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to specialize a `genericOp` into a convolution op of type
+/// `ConvOpTy` with the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ // Ops with no dilations and no strides.
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
+}
+
+/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
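+/// For example (illustrative), a generic recognized as a depthwise NCHW 2-D
+/// convolution is rewritten to linalg::DepthwiseConv2DNchwChwOp with the
+/// recovered strides and dilations attached as attributes.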
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+#define CONV_OP_SPECIALIZER(ConvOpTy) \
+ if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
+ return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
+ strides);
+ // -----------------------------
+ // Convolution ops.
+ // -----------------------------
+ CONV_OP_SPECIALIZER(linalg::Conv1DOp);
+ CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
+ CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+ CONV_OP_SPECIALIZER(linalg::Conv3DOp);
+ // -----------------------------
+ // Depthwise Convolution ops.
+ // -----------------------------
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
+ // -----------------------------
+ // Pooling ops.
+ // -----------------------------
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
+#undef CONV_OP_SPECIALIZER
+ return failure();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -316,6 +379,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
+
+ // Convolution - e.g. *conv/pooling*
+ if (isaConvolutionOpInterface(genericOp)) {
+ return specializeLinalgConvolutions(rewriter, genericOp);
+ }
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 705d6f2..8e14ef4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
SmallVector<OpFoldResult> allShapeSizes =
op.createFlatListOfOperandDims(b, op.getLoc());
AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
- if (!shapeSizesToLoopsMap)
- return failure();
+ assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap");
auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 57b610b..50a84ac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -167,7 +167,7 @@ struct LinalgOpTilingInterface
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
if (!dimExpr)
- continue;
+ return failure();
unsigned position = dimExpr.getPosition();
auto it = mappedOffsets.find(position);
if (it != mappedOffsets.end()) {
@@ -216,8 +216,6 @@ struct LinalgOpTilingInterface
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
- std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
- iterationSpaceSizes;
SmallVector<AffineMap> indexingMaps =
llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
@@ -359,6 +357,32 @@ struct LinalgOpTilingInterface
/// Inline the op payload and store the result.
return inlinePayload(builder, linalgOp, ivs, indexedValues);
}
+
+ bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ // The verifier gives all the necessary requirements for consumer fusion.
+ return true;
+ }
+
+ bool isOpFusableWithProducerSlices(
+ Operation *op, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+
+ auto linalgOp = cast<LinalgOp>(op);
+ SmallVector<AffineMap> indexingMaps =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+ OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+ return linalgOp.getMatchingIndexingMap(&opOperand);
+ });
+ // Check that offsets/sizes are consistent across all operands.
+ OpBuilder b(op);
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps,
+ allOffsets, allSizes, mappedOffsets,
+ mappedSizes));
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bd25e94..67e2b9f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// 2. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied.
- PackingMetadata packingMetadata = computePackingMetadata(
- packedTensorType.getRank(), packOp.getInnerDimsPos());
+ PackingMetadata packingMetadata;
SmallVector<int64_t> packedToStripMinedShapePerm =
- getPackInverseDestPerm(packOp);
+ getPackInverseDestPerm(packOp, packingMetadata);
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
@@ -1168,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
"this is not supported ATM!");
}
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
int64_t srcRank = packOp.getSourceRank();
- int64_t destRank = packOp.getDestRank();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1263,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
writeSizes.push_back(tileSizeOfr);
}
- // TODO: Add a constructor for tensor.insert_slice that doesn't require
- // strides nor offsets.
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-
auto insert = tensor::InsertSliceOp::create(
- rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
- writeOffsets, writeSizes, writeStrides);
+ rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
// 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
@@ -1280,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
- int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
@@ -1297,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Value source = unpackOp.getSource();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
// The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1308,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// outer-tiled-dims being all 1), this will be
// [ outer-untiled-dims, tile-sizes ]
SmallVector<OpFoldResult> extractSliceSizes;
- // The offset and strides attributes for ExtractSliceOp.
- SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
// Shape for EmptyOp that's used as the init value for TransposeOp below.
// This should be:
@@ -1365,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Type elemType = unpackOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
Value innerTile = tensor::ExtractSliceOp::create(
- rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
- extractSliceSizes, extractSliceStrides);
+ rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
// 2. Transpose the tile to match the outer corresponding tile order.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
@@ -1382,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// 3. Handle in-complete tiles if needed. It truncates trailing data from the
// transposed tile.
- int numLoops = shapeForEmptyOp.size();
- SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
- SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
SmallVector<OpFoldResult> tileSizes;
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
for (auto i : llvm::seq<unsigned>(0, destRank)) {
@@ -1394,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
}
auto partialTile =
- tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
- tileOffsets, tileSizes, tileStrides);
+ tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
+ transposedOp.getResult()[0], tileSizes);
// 4. Insert the result to the destination tensor.
SmallVector<OpFoldResult> writeSizes;
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
for (int i = 0, idx = 0; i < destRank; ++i) {
if (dimAndTileMapping.count(i) || destShape[i] != 1)
writeSizes.push_back(tileSizes[idx++]);
@@ -1408,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
writeSizes.push_back(oneIdxAttr);
}
auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
- unpackOp.getDest(), writeOffsets,
- writeSizes, writeStrides);
+ unpackOp.getDest(), writeSizes);
rewriter.replaceOp(unpackOp, insert.getResult());
return success();
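The decomposition patterns above now call tensor.insert_slice / tensor.extract_slice builders that take only sizes. A minimal sketch of the defaulting those builders are assumed to perform, matching what the deleted offset/stride vectors spelled out explicitly; the helpers below are hypothetical and only document that assumption.

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helpers: the size-only slice builders are assumed to fill in
// an offset of 0 and a stride of 1 for every dimension, exactly like the
// writeOffsets/writeStrides vectors removed above.
static SmallVector<OpFoldResult> zeroOffsets(OpBuilder &b, int64_t rank) {
  return SmallVector<OpFoldResult>(rank, b.getIndexAttr(0));
}
static SmallVector<OpFoldResult> unitStrides(OpBuilder &b, int64_t rank) {
  return SmallVector<OpFoldResult>(rank, b.getIndexAttr(1));
}
```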
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cb6199f..bb3bccd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
auto vectorType = state.getCanonicalVecType(
getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
+ SmallVector<Value> indices(linalgOp.getRank(outputOperand),
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
+
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
- SmallVector<Value> indices(
- linalgOp.getRank(outputOperand),
- arith::ConstantIndexOp::create(rewriter, loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType);
assert(value.getType() == vectorType && "Incorrect type");
write = vector::TransferWriteOp::create(
@@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
assert(value.getType() == vectorType && "Incorrect type");
write = vector::TransferWriteOp::create(rewriter, loc, value,
- outputOperand->get(), ValueRange{});
+ outputOperand->get(), indices);
}
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
-/// Given a linalg::PackOp, return the `dest` shape before any packing
-/// permutations.
-static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
- ArrayRef<int64_t> destShape) {
- return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
-}
-
/// Determines whether a mask for xfer_write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return mlir::vector::maskOperation(builder, write, maskForWrite);
}
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-/// {in_bounds = [true, true, true]} :
-/// tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-/// to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-/// {in_bounds = [true, true, true, true, true]}
-/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
- ArrayRef<int64_t> inputVectorSizes,
- SmallVectorImpl<Value> &newResults) {
- // TODO: Introduce a parent class that will handle the insertion point update.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(packOp);
-
- Location loc = packOp.getLoc();
- std::optional<Value> padValue = packOp.getPaddingValue()
- ? std::optional(packOp.getPaddingValue())
- : std::nullopt;
-
- // If the input vector sizes are not provided, then the vector sizes are
- // determined by the result tensor shape. In case the vector sizes aren't
- // provided, we update the inBounds attribute instead of masking.
- bool useInBoundsInsteadOfMasking = false;
- if (inputVectorSizes.empty()) {
- ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
- inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
- useInBoundsInsteadOfMasking = true;
- }
-
- // Create masked TransferReadOp.
- SmallVector<int64_t> inputShape(inputVectorSizes);
- auto innerTiles = packOp.getStaticInnerTiles();
- auto innerDimsPos = packOp.getInnerDimsPos();
- auto outerDimsPerm = packOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty())
- applyPermutationToVector(inputShape,
- invertPermutationVector(outerDimsPerm));
- for (auto [idx, size] : enumerate(innerTiles))
- inputShape[innerDimsPos[idx]] *= size;
- auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, packOp.getSource(), inputShape, padValue,
- useInBoundsInsteadOfMasking,
- /*inputScalableVecSizes=*/{});
-
- // Create ShapeCastOp.
- SmallVector<int64_t> destShape(inputVectorSizes);
- destShape.append(innerTiles.begin(), innerTiles.end());
- auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
- packOp.getDestType().getElementType());
- auto shapeCastOp =
- vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
- // Create TransposeOp.
- auto destPermutation =
- invertPermutationVector(getPackInverseDestPerm(packOp));
- auto transposeOp = vector::TransposeOp::create(
- rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
- // Create TransferWriteOp.
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, transposeOp.getResult(), packOp.getDest());
- newResults.push_back(write->getResult(0));
- return success();
-}
-
/// Given the re-associations, "collapses" the input Vector type
///
/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
+/// Vectorize `linalg.pack` as:
+/// * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined from the destination tensor static shape.
+/// * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+/// %pack = tensor.pack %src
+/// inner_dims_pos = [2, 1]
+/// inner_tiles = [16, 2]
+/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+/// %read = vector.transfer_read %src
+/// : tensor<32x7x16xf32>, vector<32x8x16xf32>
+/// %sc = vector.shape_cast %read
+/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %tr into %dest
+/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == packOp.getDestRank() &&
+ "Invalid number of input vector sizes!");
+ }
+
+ // TODO: Introduce a parent class that will handle the insertion point update.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ Location loc = packOp.getLoc();
+ std::optional<Value> padValue = packOp.getPaddingValue()
+ ? std::optional(packOp.getPaddingValue())
+ : std::nullopt;
+
+ SmallVector<int64_t> destShape =
+ SmallVector<int64_t>(packOp.getDestType().getShape());
+
+ // This is just a convenience alias to clearly communicate that the input
+ // vector sizes determine the _write_ sizes.
+ ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ destination tensor
+  // shape. In addition, use the inBounds attribute instead of masking.
+ bool useInBoundsInsteadOfMasking = false;
+ if (writeVectorSizes.empty()) {
+ if (ShapedType::isDynamicShape(destShape))
+ return rewriter.notifyMatchFailure(packOp,
+ "unable to infer vector sizes");
+
+ writeVectorSizes = destShape;
+ useInBoundsInsteadOfMasking = true;
+ }
+
+ // Compute pre-transpose-write-vector-type, i.e. the write vector type
+ // _before_ the transposition (i.e. before dimension permutation). This is
+ // done by inverting the permutation/transposition that's part of the Pack
+ // operation. This type is required to:
+ // 1) compute the read vector type for masked-read below, and
+ // 2) generate shape-cast Op below that expands the read vector type.
+ PackingMetadata packMetadata;
+  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+  applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
+  auto preTransposeWriteVecType = VectorType::get(
+      preTransposeWriteVecSizes, packOp.getType().getElementType());
+
+  // Compute the vector type for the _read_ operation. This is simply
+ // pre-transpose-write-vector-type with the dimensions collapsed
+ // as per the Pack operation.
+ VectorType readVecType = getCollapsedVecType(
+ preTransposeWriteVecType,
+ getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ rewriter.getContext(), packMetadata.reassociations)));
+
+ // Create masked TransferReadOp.
+ auto maskedRead = vector::createReadOrMaskedRead(
+ rewriter, loc, packOp.getSource(), readVecType, padValue,
+ useInBoundsInsteadOfMasking);
+
+ // Create ShapeCastOp.
+ auto shapeCastOp = vector::ShapeCastOp::create(
+ rewriter, loc, preTransposeWriteVecType, maskedRead);
+
+ // Create TransposeOp.
+ auto destPermutation = invertPermutationVector(destInvPermutation);
+ auto transposeOp = vector::TransposeOp::create(
+ rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+ // Create TransferWriteOp.
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, transposeOp.getResult(), packOp.getDest());
+ newResults.push_back(write->getResult(0));
+ return success();
+}
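A minimal, MLIR-free check of the permutation arithmetic used above, with the numbers from the documented example: the write vector sizes [32, 4, 1, 16, 2], permuted by the inverse destination permutation (inferred here as [0, 1, 4, 2, 3]), give the pre-transpose sizes [32, 4, 2, 1, 16], and inverting that permutation recovers the transpose order [0, 1, 3, 4, 2] from the doc comment. The helpers mirror, but are not, applyPermutationToVector and invertPermutationVector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<std::int64_t> apply(const std::vector<std::int64_t> &v,
                                       const std::vector<std::int64_t> &perm) {
  // result[i] = v[perm[i]], matching applyPermutationToVector's convention.
  std::vector<std::int64_t> out(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    out[i] = v[perm[i]];
  return out;
}

static std::vector<std::int64_t> invert(const std::vector<std::int64_t> &perm) {
  // inv[perm[i]] = i, matching invertPermutationVector's convention.
  std::vector<std::int64_t> inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<std::int64_t>(i);
  return inv;
}

int main() {
  const std::vector<std::int64_t> writeSizes = {32, 4, 1, 16, 2};
  const std::vector<std::int64_t> destInvPerm = {0, 1, 4, 2, 3};
  assert(apply(writeSizes, destInvPerm) ==
         (std::vector<std::int64_t>{32, 4, 2, 1, 16}));
  assert(invert(destInvPerm) == (std::vector<std::int64_t>{0, 1, 3, 4, 2}));
  return 0;
}
```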
+
/// Vectorize `linalg.unpack` as:
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the input tensor static shape.
@@ -1960,16 +1968,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
if (inputVectorSizes.empty()) {
if (ShapedType::isDynamicShape(sourceShape))
- return failure();
+ return rewriter.notifyMatchFailure(unpackOp,
+ "Unable to infer vector sizes!");
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
useInBoundsInsteadOfMasking = true;
}
// -- Generate the read operation --
+ VectorType readVecType =
+ VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+ readScalableVectorFlags);
Value readResult = vector::createReadOrMaskedRead(
- rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
- useInBoundsInsteadOfMasking, readScalableVectorFlags);
+ rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+ useInBoundsInsteadOfMasking);
// -- Generate the transpose operation --
PackingMetadata packMetadata;
@@ -2015,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
+ auto readType = VectorType::get(inputVectorSizes, padValue.getType());
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+ rewriter, loc, padOp.getSource(), readType, padValue,
+ /*useInBoundsInsteadOfMasking=*/false);
// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2212,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, opOperand.get(), readType.getShape(),
+ rewriter, loc, opOperand.get(), readType,
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
- /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ /*useInBoundsInsteadOfMasking=*/false);
vecOperands.push_back(read);
}
@@ -2443,6 +2456,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
+  // TODO: Relax this condition.
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
LDBG() << "pad value is not constant: " << packOp;
return failure();
@@ -3154,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
SmallVector<Value> readIndices(
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, source, vecType.getShape(), padValue,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, source, vecType, padValue,
+ /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
// Create write
auto writeIndices =
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722..01e6e1e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -171,29 +171,24 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
namespace mlir {
namespace linalg {
-SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
+SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
+ PackingMetadata &metadata) {
- PackingMetadata pMetadata;
int64_t packedRank = packOp.getDestType().getRank();
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
SmallVector<int64_t> packInvDestPerm =
- computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
+ computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return packInvDestPerm;
}
-SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
- PackingMetadata metadata;
- return getUnPackInverseSrcPerm(unpackOp, metadata);
-}
-
SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
PackingMetadata &metadata) {
- int64_t unpackRank = unpackOp.getSourceType().getRank();
+ int64_t packedRank = unpackOp.getSourceType().getRank();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
SmallVector<int64_t> unpackInvSrcPerm =
- computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+ computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return unpackInvSrcPerm;
}
@@ -240,6 +235,731 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
+/// ext* ops.
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+ BlockArgument blockArg = dyn_cast<BlockArgument>(val);
+  if (blockArg)
+ return blockArg;
+
+ Operation *defOp = val.getDefiningOp();
+ if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
+ !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
+ !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+ return nullptr;
+ }
+ return dyn_cast<BlockArgument>(defOp->getOperand(0));
+}
+
+/// Utility to match block body for convolution ops.
+/// The body is expected to yield:
+///       %out + (%lhs * %rhs)
+/// where %lhs, %rhs and %out are block arguments, and
+///       %lhs and %rhs may be wrapped in an optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+ Operation *addOp = yieldVal.getDefiningOp();
+ if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
+ return false;
+
+ Operation *mulOp = addOp->getOperand(1).getDefiningOp();
+ if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+ return false;
+
+ BlockArgument lhsBlockArg =
+ getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+ BlockArgument rhsBlockArg =
+ getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+ BlockArgument outBlockArg =
+ getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+ if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
+ lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
+ outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
+ rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
+ return false;
+ return true;
+}
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+ Operation *defOp = yieldVal.getDefiningOp();
+ if (!(isa_and_present<OpTypes>(defOp) || ...))
+ return false;
+
+ BlockArgument lhsArg =
+ getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+ BlockArgument rhsArg =
+ getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+ if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
+ rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
+ rhsArg.getArgNumber() != 0)
+ return false;
+ return true;
+}
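The dispatch above relies on a C++17 fold expression over the OpTypes pack. A standalone sketch of the same idiom, folded over static types rather than dynamic op checks (illustrative only):

```cpp
#include <type_traits>

// True if T is any one of the listed types; mirrors the shape of
// (isa_and_present<OpTypes>(defOp) || ...) used above.
template <typename... Ts, typename T>
constexpr bool isAnyOf(T) {
  return (std::is_same_v<T, Ts> || ...);
}

static_assert(isAnyOf<int, float>(42), "int is in the list");
static_assert(!isAnyOf<int, float>(42u), "unsigned is not in the list");
```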
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+ body);
+}
+
+// max_unsigned ops should not allow float data type.
+// TODO(#164800): Retire OPDSL logic.
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+ body);
+}
+
+// min_unsigned ops should not allow float data type.
+// TODO(#164800): Retire OPDSL logic.
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
+ uint32_t dimIndex) {
+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+ if (dimIndex < affineMap.getNumResults())
+ return affineMap.getResult(dimIndex);
+ return nullptr;
+}
+
+/// Check if `expr` is either:
+/// - a dimension expr alone (implying multiplication by 1), or
+/// - a multiplication of dimension expr by any positive constant != 1
+/// In both cases we will capture the dimension expression into `dim` and
+/// return the constant multiplier. Returns -1 in case of a match failure.
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+ if ((dim = dyn_cast<AffineDimExpr>(expr)))
+ return 1;
+
+ auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+ return -1;
+
+ AffineExpr lhs = mulExpr.getLHS();
+ AffineExpr rhs = mulExpr.getRHS();
+
+ AffineConstantExpr cst = nullptr;
+ if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+ (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+ ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+ (cst = dyn_cast<AffineConstantExpr>(lhs))))
+ return cst.getValue();
+ return -1;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following
+/// relation (allowing the two addends in either order):
+/// indexingMaps[0].getResult(iDim) ==
+/// indexingMaps[1].getResult(fDim) * <c0> +
+/// indexingMaps[n-1].getResult(oDim) * <c1>
+/// where,
+/// - c0 and c1 can be any constant,
+/// - n is the size of the indexingMaps' array,
+/// - 0, 1 and n-1 are input, filter and output map indices respectively,
+/// - iDim, fDim and oDim are the input, filter and output dimension
+/// indices in their respective indexing maps
+/// Example:
+/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+///
+/// Here,
+/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
+/// Therefore,
+/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+/// would return true and update dilation = 3 and stride = 2
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+ unsigned fDim, unsigned oDim,
+ int64_t &dilation, int64_t &stride) {
+ unsigned inputMapIdx = 0, filterMapIdx = 1,
+ outputMapIdx = indexingMaps.size() - 1;
+ AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
+ auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
+ if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+ return false;
+
+ AffineExpr dim0, dim1;
+ int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+ int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
+
+ if (c0 == -1 || c1 == -1)
+ return false;
+ // Pattern matched with dims and constants extracted.
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
+ AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
+ if (dim0 == fExpr && dim1 == oExpr) {
+ dilation = c0;
+ stride = c1;
+ return true;
+ }
+ if (dim1 == fExpr && dim0 == oExpr) {
+ dilation = c1;
+ stride = c0;
+ return true;
+ }
+ return false;
+}
+
+/// Returns true if the given indexing maps match the expected indexing
+/// maps.
+static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
+ ArrayAttr indexingMaps, MLIRContext *context) {
+ SmallVector<AffineMap, 4> expectedIndexingMaps =
+ AffineMap::inferFromExprList(mapListExpected, context);
+ return indexingMaps ==
+ ArrayAttr::get(
+ context, llvm::to_vector<4>(llvm::map_range(
+ expectedIndexingMaps, [&](AffineMap m) -> Attribute {
+ return AffineMapAttr::get(m);
+ })));
+}
+
+/// Enum representing pooling operation types used by ConvMatcherBuilder.
+enum class PoolingType {
+ None,
+ MaxSigned,
+ MaxUnsigned,
+ MinSigned,
+ MinUnsigned,
+ Sum
+};
+
+/// Helper class for building convolution op matchers with minimal boilerplate.
+/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
+/// as Pooling ops.
+///
+/// Usage: Create an instance with the op, spatial rank, and output pointers for
+/// extracted dilations/strides. Then chain matchStride() calls for each spatial
+/// dimension, followed by matchMaps() to verify indexing maps, and finally
+/// matchBody() to verify the operation body pattern.
+///
+/// The `matched` flag starts as `true` and is set to `false` if any match step
+/// fails. This allows chaining multiple match calls; once any match fails, all
+/// subsequent calls become no-ops and the final result is `false`.
+///
+/// The `dilations` and `strides` pointers are output parameters that get
+/// populated with the extracted dilation and stride values from the operation's
+/// indexing maps during matchStride() calls. These values are initially set to
+/// 1 for each spatial dimension and updated as patterns are matched.
+class ConvMatcherBuilder {
+ LinalgOp op;
+ MLIRContext *ctx;
+ SmallVector<int64_t> *dilations, *strides;
+ ArrayAttr indexingMaps;
+ PoolingType poolingType;
+ bool matched = true;
+
+public:
+ ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
+ SmallVector<int64_t> *s,
+ PoolingType poolingType = PoolingType::None)
+ : op(op), ctx(op->getContext()), dilations(d), strides(s),
+ indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
+ *dilations = SmallVector<int64_t>(spatialRank, 1);
+ *strides = SmallVector<int64_t>(spatialRank, 1);
+ }
+
+ /// Get affine dimension expression for dimension `i`.
+ AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
+
+ /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
+ AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
+ return base * (*strides)[idx] + kernel * (*dilations)[idx];
+ }
+
+ /// Match stride/dilation pattern for a spatial dimension.
+ /// Returns *this for method chaining.
+ ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
+ unsigned idx) {
+ if (matched) {
+ matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+ (*dilations)[idx], (*strides)[idx]);
+ }
+ return *this;
+ }
+
+ /// Match expected indexing maps layout. Returns *this for method chaining.
+ ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+ if (matched)
+ matched &= convLayoutMatches(maps, indexingMaps, ctx);
+ return *this;
+ }
+
+ /// Match body pattern. This should be called last.
+ bool matchBody() {
+ if (!matched)
+ return false;
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ switch (poolingType) {
+ case PoolingType::None:
+ return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
+ case PoolingType::MaxSigned:
+ return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MaxUnsigned:
+ return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MinSigned:
+ return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MinUnsigned:
+ return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::Sum:
+ return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
+ }
+ return false;
+ }
+};
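The specializations below are the intended users of this builder. For reference, a hypothetical caller-side sketch; the names are illustrative, the matcher declaration is assumed to be visible from the Linalg utilities header, and the op is assumed to implement ConvolutionOpInterface, as each matcher asserts.

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;

// Hypothetical caller: test whether `op` is a plain 2-D convolution and, if
// so, read back the dilations/strides extracted from its indexing maps.
static void classifyConv(linalg::LinalgOp op) {
  SmallVector<int64_t> dilations, strides;
  if (linalg::isaConvolutionOpOfType<linalg::Conv2DOp>(op, &dilations,
                                                       &strides)) {
    // One entry per spatial dimension, e.g. dilations = {1, 1} and
    // strides = {1, 1} for the undilated, unit-stride case.
  }
}
```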
+
+//===----------------------------------------------------------------------===//
+// Matchers for specific convolution operations.
+//===----------------------------------------------------------------------===//
+
+// #inputMap = affine_map<(W, w) -> (W + w)>
+// #filterMap = affine_map<(W, w) -> (w)>
+// #outputMap = affine_map<(W, w) -> (W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr W = m.dim(0);
+ AffineExpr w = m.dim(1);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
+// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
+// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DNwcWcfOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr F = m.dim(2);
+ AffineExpr w = m.dim(3);
+ AffineExpr c = m.dim(4);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+ /*filterMap=*/{w, c, F},
+ /*outputMap=*/{N, W, F}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
+// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
+// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DNcwFcwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr F = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr c = m.dim(3);
+ AffineExpr w = m.dim(4);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+ /*filterMap=*/{F, c, w},
+ /*outputMap=*/{N, F, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)>
+// #filterMap = affine_map<(H, W, h, w) -> (h, w)>
+// #outputMap = affine_map<(H, W, h, w) -> (H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ AffineExpr H = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr h = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{H, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
+// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
+// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv3DOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ AffineExpr D = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr d = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2)},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{D, H, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
+// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{C, w},
+ /*outputMap=*/{N, C, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C},
+ /*outputMap=*/{N, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
+// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr CM = m.dim(3);
+ AffineExpr w = m.dim(4);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C, CM},
+ /*outputMap=*/{N, W, C, CM}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{C, h, w},
+ /*outputMap=*/{N, C, H, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+// -> (N, D + d, H + h, W + w, C)>
+// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+// -> (d, h, w, C, CM)>
+// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+// -> (N, D, H, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr D = m.dim(1);
+ AffineExpr H = m.dim(2);
+ AffineExpr W = m.dim(3);
+ AffineExpr CM = m.dim(4);
+ AffineExpr d = m.dim(5);
+ AffineExpr h = m.dim(6);
+ AffineExpr w = m.dim(7);
+ AffineExpr C = m.dim(8);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w, C, CM},
+ /*outputMap=*/{N, D, H, W, C, CM}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MaxSigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MinSigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcSumOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::Sum);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MaxUnsigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MinUnsigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 1382c7ac..d358362 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
+ MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 6ff63df..a1e3f10 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dfa2e4e..5404238 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//
-static bool isSupportedElementType(Type type) {
- return llvm::isa<MemRefType>(type) ||
- OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
MemRefType type = getType();
- if (!isSupportedElementType(type.getElementType()))
- return {};
if (!type.hasStaticShape())
return {};
// Make sure the memref contains only a single element.
@@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
OpBuilder &builder) {
- assert(isSupportedElementType(slot.elemType));
- // TODO: support more types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](MemRefType t) {
- return memref::AllocaOp::create(builder, getLoc(), t);
- })
- .Default([&](Type t) {
- return arith::ConstantOp::create(builder, getLoc(), t,
- builder.getZeroAttr(t));
- });
+ return ub::PoisonOp::create(builder, getLoc(), slot.elemType);
}
std::optional<PromotableAllocationOpInterface>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f..1035d7c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return subview.getDynamicSize(sourceIndex);
}
- if (auto sizeInterface =
- dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
- assert(sizeInterface.isDynamicSize(unsignedIndex) &&
- "Expected dynamic subview size");
- return sizeInterface.getDynamicSize(unsignedIndex);
- }
-
// dim(memrefcast) -> dim
if (succeeded(foldMemRefCast(*this)))
return getResult();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index bd02516..c9352e8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
- if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
+    // ViewLikeOpInterface does not by itself guarantee that the base pointer
+    // is preserved (`memref.view` is one such example), so only check for a
+    // few specific cases.
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
+ !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 214410f..3667fdb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -347,28 +347,55 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
- llvm::TypeSwitch<Operation *, void>(loadOp)
+
+ return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
.Case([&](affine::AffineLoadOp op) {
rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices);
+ return success();
})
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
+ return success();
})
.Case([&](vector::LoadOp op) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
+ return success();
})
.Case([&](vector::MaskedLoadOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
+ return success();
+ })
+ .Case([&](vector::TransferReadOp op) {
+ // We only support minor identity maps in the permutation attribute.
+ if (!op.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ // We only support the case where the source of the expand shape has
+ // rank greater than or equal to the vector rank.
+ const int64_t sourceRank = sourceIndices.size();
+ const int64_t vectorRank = op.getVectorType().getRank();
+ if (sourceRank < vectorRank)
+ return failure();
+
+ // We need to construct a new minor identity map since we will have lost
+ // some dimensions in folding away the expand shape.
+ auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
+ op.getContext());
+
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ op, op.getVectorType(), expandShapeOp.getViewSource(),
+ sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+ op.getInBounds());
+ return success();
})
.DefaultUnreachable("unexpected operation");
- return success();
}
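A minimal sketch of the minor identity map constructed above when the expand_shape is folded away; it assumes an MLIR build and uses illustrative ranks (a rank-4 source feeding a rank-2 transfer_read), where only the two innermost dimensions survive.

```cpp
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // sourceRank = 4, vectorRank = 2, as in the pattern above.
  mlir::AffineMap minorId = mlir::AffineMap::getMinorIdentityMap(4, 2, &ctx);
  // Prints "(d0, d1, d2, d3) -> (d2, d3)": only the innermost dims remain.
  minorId.print(llvm::errs());
  llvm::errs() << "\n";
  return 0;
}
```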
template <typename OpTy>
@@ -659,6 +686,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15..c498c8a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
if (!dimIndex)
return failure();
- ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
- reifiedResultShapes)))
+ FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+ rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+ if (failed(replacement))
return failure();
- unsigned resultNumber = dimValue.getResultNumber();
- // Do not apply pattern if the IR is invalid (dim out of bounds).
- if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
- return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
- Value replacement = getValueOrCreateConstantIndexOp(
- rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
- rewriter.replaceOp(dimOp, replacement);
+ // Check if the OpFoldResult is empty (unreifiable dimension).
+ if (!replacement.value())
+ return failure();
+ Value replacementVal = getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), replacement.value());
+ rewriter.replaceOp(dimOp, replacementVal);
return success();
}
};
@@ -166,12 +165,14 @@ namespace {
struct ResolveRankedShapeTypeResultDimsPass final
: public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
ResolveRankedShapeTypeResultDimsPass> {
+ using Base::Base;
void runOnOperation() override;
};
struct ResolveShapedTypeResultDimsPass final
: public memref::impl::ResolveShapedTypeResultDimsPassBase<
ResolveShapedTypeResultDimsPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+ if (errorOnPatternIterationLimit && failed(result)) {
+ getOperation()->emitOpError(
+ "dim operation resolution hit pattern iteration limit");
return signalPassFailure();
+ }
}
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+ if (errorOnPatternIterationLimit && failed(result)) {
+ getOperation()->emitOpError(
+ "dim operation resolution hit pattern iteration limit");
return signalPassFailure();
+ }
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 14152c5..e5cc41e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -268,61 +268,82 @@ struct SubViewOpInterface
MemRefType sourceType = subView.getSource().getType();
// For each dimension, assert that:
- // 0 <= offset < dim_size
- // 0 <= offset + (size - 1) * stride < dim_size
+ // For empty slices (size == 0) : 0 <= offset <= dim_size
+ // For non-empty slices (size > 0): 0 <= offset < dim_size
+    //                                    0 <= offset + (size - 1) * stride <
+    //                                    dim_size
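As a standalone scalar restatement of the conditions listed above (plain C++, separate from the IR-building code that follows):

```cpp
#include <cassert>
#include <cstdint>

// Scalar analogue of the per-dimension subview check emitted below.
static bool subviewDimInBounds(std::int64_t offset, std::int64_t size,
                               std::int64_t stride, std::int64_t dimSize) {
  if (size == 0)
    return 0 <= offset && offset <= dimSize; // empty slice: boundary allowed
  std::int64_t lastPos = offset + (size - 1) * stride;
  return 0 <= offset && offset < dimSize && 0 <= lastPos && lastPos < dimSize;
}

int main() {
  assert(subviewDimInBounds(4, 3, 2, 10));  // last accessed index is 8
  assert(subviewDimInBounds(10, 0, 1, 10)); // empty slice at the boundary
  assert(!subviewDimInBounds(4, 4, 2, 10)); // index 10 runs out of bounds
  return 0;
}
```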
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
+
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
- // Reset insertion point to before the operation for each dimension
+ // Reset insertion point to before the operation for each dimension.
builder.setInsertionPoint(subView);
+
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
subView.getMixedSizes()[i]);
Value stride = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedStrides()[i]);
-
- // Verify that offset is in-bounds.
Value dimSize = metadataOp.getSizes()[i];
- Value offsetInBounds =
- generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(builder, loc, offsetInBounds,
+
+ // Verify that offset is in-bounds (conditional on slice size).
+ Value sizeIsZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, size, zero);
+ auto offsetCheckIf = scf::IfOp::create(
+ builder, loc, sizeIsZero,
+ [&](OpBuilder &b, Location loc) {
+ // For empty slices, offset can be at the boundary: 0 <= offset <=
+ // dimSize.
+ Value offsetGEZero = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sge, offset, zero);
+ Value offsetLEDimSize = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sle, offset, dimSize);
+ Value emptyOffsetValid =
+ arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
+ scf::YieldOp::create(b, loc, emptyOffsetValid);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // For non-empty slices, offset must be a valid index:
+ // 0 <= offset < dimSize.
+ Value offsetInBounds =
+ generateInBoundsCheck(b, loc, offset, zero, dimSize);
+ scf::YieldOp::create(b, loc, offsetInBounds);
+ });
+
+ Value offsetCondition = offsetCheckIf.getResult(0);
+ cf::AssertOp::create(builder, loc, offsetCondition,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));
- // Only verify if size > 0
+ // Verify that the slice endpoint is in-bounds (only for non-empty
+ // slices).
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);
+ auto ifOp = scf::IfOp::create(
+ builder, loc, sizeIsNonZero,
+ [&](OpBuilder &b, Location loc) {
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
+ Value sizeMinusOneTimesStride =
+ arith::MulIOp::create(b, loc, sizeMinusOne, stride);
+ Value lastPos =
+ arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds =
+ generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(b, loc, lastPosInBounds);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value trueVal =
+ arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
+ scf::YieldOp::create(b, loc, trueVal);
+ });
- auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
- sizeIsNonZero, /*withElseRegion=*/true);
-
- // Populate the "then" region (for size > 0).
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
- Value sizeMinusOneTimesStride =
- arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
- Value lastPos =
- arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
- Value lastPosInBounds =
- generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
-
- scf::YieldOp::create(builder, loc, lastPosInBounds);
-
- // Populate the "else" region (for size == 0).
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- Value trueVal =
- arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
- scf::YieldOp::create(builder, loc, trueVal);
-
- builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);
-
cf::AssertOp::create(
builder, loc, finalCondition,
generateErrorMessage(op,
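For reference, the per-dimension condition that the emitted scf.if and cf.assert ops encode can be written as plain C++. This is an illustrative sketch of the logic only, not code from the patch:

// Sketch: the runtime condition asserted for one subview dimension.
static bool subviewDimInBounds(int64_t offset, int64_t size, int64_t stride,
                               int64_t dimSize) {
  if (size == 0) {
    // Empty slice: the offset may sit on the boundary.
    return 0 <= offset && offset <= dimSize;
  }
  // Non-empty slice: offset and the last accessed position must be valid.
  if (!(0 <= offset && offset < dimSize))
    return false;
  int64_t lastPos = offset + (size - 1) * stride;
  return 0 <= lastPos && lastPos < dimSize;
}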
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 6200366..e548698 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -133,17 +133,20 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
}
/// Returns true if all the uses of op are not read/load.
-/// There can be SubviewOp users as long as all its users are also
+/// There can be view-like-op users as long as all of their users are also
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
/// returns false uses is unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
std::vector<Operation *> opUses;
for (OpOperand &use : op->getUses()) {
Operation *useOp = use.getOwner();
+ // The use escapes the scope, so conservatively assume it may be read.
+ if (useOp->mightHaveTrait<OpTrait::IsTerminator>())
+ return false;
if (isa<memref::DeallocOp>(useOp) ||
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
- (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
+ (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
opUses.push_back(useOp);
continue;
}
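As a usage sketch of the generalized check (the allocation op and what the caller does with the collected uses are assumptions for illustration):

// Sketch: `allocOp` is assumed to produce a memref. If every transitive
// user is a store, dealloc, or a view-like op whose users are all stores,
// `uses` is filled in and the buffer contents are never observed.
std::vector<Operation *> uses;
if (resultIsNotRead(allocOp, uses)) {
  // The buffer is effectively write-only; e.g. the stores collected in
  // `uses` could be elided by the caller.
}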
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 2a857ed..0d05313 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
- auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
+ auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);
Type elementType = getElementTypeOrSelf(memref.getType());
auto vt = VectorType::get(vectorShape, elementType);
@@ -727,7 +727,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
toStore.push_back(v);
});
- return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
+ return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
@@ -792,7 +792,7 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
if (failed(maybeInfo))
return failure();
- MmaSyncInfo info = *maybeInfo;
+ const MmaSyncInfo &info = *maybeInfo;
auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
index 40e769e..1d775fb 100644
--- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) {
return mlir::emitError(loc, "not yet implemented: " + message);
}
+bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
+ Operation **definingOpPtr) {
+ if (impl)
+ return impl->isValidSymbolUse(user, symbol, definingOpPtr);
+ return acc::isValidSymbolUse(user, symbol, definingOpPtr);
+}
+
} // namespace acc
} // namespace mlir
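A hedged sketch of how a pass might query this entry point; retrieving the analysis via `getAnalysis` mirrors how `OpenACCSupport` is obtained elsewhere in this patch, and `user`/`symbolRef` are placeholder names:

// Sketch: falls back to the dialect-independent acc::isValidSymbolUse when
// no dialect-specific `impl` is registered on the analysis.
acc::OpenACCSupport &support = getAnalysis<acc::OpenACCSupport>();
Operation *definingOp = nullptr;
if (!support.isValidSymbolUse(user, symbolRef, &definingOp))
  user->emitOpError() << "invalid use of symbol " << symbolRef;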
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 35eba72..47f1222 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
@@ -203,12 +204,91 @@ struct MemRefPointerLikeModel
return false;
}
+
+ mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> srcPtr,
+ Type valueType) const {
+ // Load from a memref - only valid for scalar memrefs (rank 0).
+ // This is because the address computation for memrefs is part of the load
+ // (and not computed separately), but the API does not have arguments for
+ // indexing.
+ auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
+ if (!memrefValue)
+ return {};
+
+ auto memrefTy = memrefValue.getType();
+
+ // Only load from scalar memrefs (rank 0)
+ if (memrefTy.getRank() != 0)
+ return {};
+
+ return memref::LoadOp::create(builder, loc, memrefValue);
+ }
+
+ bool genStore(Type pointer, OpBuilder &builder, Location loc,
+ Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+ // Store to a memref - only valid for scalar memrefs (rank 0)
+ // This is because the address computation for memrefs is part of the store
+ // (and not computed separately), but the API does not have arguments for
+ // indexing.
+ auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
+ if (!memrefValue)
+ return false;
+
+ auto memrefTy = memrefValue.getType();
+
+ // Only store to scalar memrefs (rank 0)
+ if (memrefTy.getRank() != 0)
+ return false;
+
+ memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
+ return true;
+ }
};
struct LLVMPointerPointerLikeModel
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
LLVM::LLVMPointerType> {
Type getElementType(Type pointer) const { return Type(); }
+
+ mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> srcPtr,
+ Type valueType) const {
+ // For LLVM pointers, we need the valueType to determine what to load
+ if (!valueType)
+ return {};
+
+ return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
+ }
+
+ bool genStore(Type pointer, OpBuilder &builder, Location loc,
+ Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+ LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
+ return true;
+ }
+};
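The two pointer-like models above add load/store hooks. A minimal usage sketch, assuming the PointerLikeType interface exposes them as `genLoad`/`genStore` methods (with the leading `Type` parameter of the external-model signatures supplied implicitly by the interface object):

// Sketch only; returns a null Value when the pointee cannot be loaded
// directly (e.g. a ranked memref with rank > 0, which would need indexing).
static Value loadThroughPointer(OpBuilder &b, Location loc,
                                TypedValue<acc::PointerLikeType> ptr,
                                Type valueType) {
  auto ptrTy = cast<acc::PointerLikeType>(ptr.getType());
  return ptrTy.genLoad(b, loc, ptr, valueType);
}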
+
+struct MemrefAddressOfGlobalModel
+ : public AddressOfGlobalOpInterface::ExternalModel<
+ MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
+ SymbolRefAttr getSymbol(Operation *op) const {
+ auto getGlobalOp = cast<memref::GetGlobalOp>(op);
+ return getGlobalOp.getNameAttr();
+ }
+};
+
+struct MemrefGlobalVariableModel
+ : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
+ memref::GlobalOp> {
+ bool isConstant(Operation *op) const {
+ auto globalOp = cast<memref::GlobalOp>(op);
+ return globalOp.getConstant();
+ }
+
+ Region *getInitRegion(Operation *op) const {
+ // GlobalOp uses attributes for initialization, not regions
+ return nullptr;
+ }
};
/// Helper function for any of the times we need to modify an ArrayAttr based on
@@ -302,6 +382,11 @@ void OpenACCDialect::initialize() {
MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
+
+ // Attach operation interfaces
+ memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
+ *getContext());
+ memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
}
//===----------------------------------------------------------------------===//
@@ -467,6 +552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
return success();
}
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
+ // Mappable types do not need a recipe because one can be generated from
+ // their API. Reductions are still rejected because no such API is
+ // available for them at this time.
+ if (mlir::acc::isMappableType(op.getVar().getType()) &&
+ !std::is_same_v<OpT, acc::ReductionOp>)
+ return success();
+
+ mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
+ if (!operandRecipe)
+ return op->emitOpError() << "recipe expected for " << operandName;
+
+ auto decl =
+ SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe);
+ if (!decl)
+ return op->emitOpError()
+ << "expected symbol reference " << operandRecipe << " to point to a "
+ << operandName << " declaration";
+ return success();
+}
+
static ParseResult parseVar(mlir::OpAsmParser &parser,
OpAsmParser::UnresolvedOperand &var) {
// Either `var` or `varPtr` keyword is required.
@@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
}
}
+static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
+ mlir::SymbolRefAttr &recipeAttr) {
+ if (failed(parser.parseAttribute(recipeAttr)))
+ return failure();
+ return success();
+}
+
+static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::SymbolRefAttr recipeAttr) {
+ p << recipeAttr;
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
@@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(
+ checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private")))
+ return failure();
return success();
}
@@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>(
+ *this, "firstprivate")))
+ return failure();
return success();
}
@@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>(
+ *this, "reduction")))
+ return failure();
return success();
}
@@ -1042,6 +1170,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+/// Remove empty acc.kernel_environment operations. If the operation has wait
+/// operands, create an acc.wait operation to preserve synchronization.
+struct RemoveEmptyKernelEnvironment
+ : public OpRewritePattern<acc::KernelEnvironmentOp> {
+ using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
+ PatternRewriter &rewriter) const override {
+ assert(op->getNumRegions() == 1 && "expected op to have one region");
+
+ Block &block = op.getRegion().front();
+ if (!block.empty())
+ return failure();
+
+ // Conservatively disable canonicalization of empty acc.kernel_environment
+ // operations if the wait operands in the kernel_environment cannot be fully
+ // represented by an acc.wait operation.
+
+ // Disable canonicalization if device type is not the default
+ if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
+ for (auto attr : deviceTypeAttr) {
+ if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
+ if (dtAttr.getValue() != mlir::acc::DeviceType::None)
+ return failure();
+ }
+ }
+ }
+
+ // Disable canonicalization if any wait segment has a devnum
+ if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
+ for (auto attr : hasDevnumAttr) {
+ if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
+ if (boolAttr.getValue())
+ return failure();
+ }
+ }
+ }
+
+ // Disable canonicalization if there are multiple wait segments
+ if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
+ if (segmentsAttr.size() > 1)
+ return failure();
+ }
+
+ // Remove empty kernel environment.
+ // Preserve synchronization by creating acc.wait operation if needed.
+ if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
+ rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
+ /*asyncOperand=*/Value(),
+ /*waitDevnum=*/Value(),
+ /*async=*/nullptr,
+ /*ifCond=*/Value());
+ else
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
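A short sketch of exercising this pattern in isolation (the `ctx` and `moduleOp` variables are assumptions); in the patch itself it is registered through the op's canonicalizer hook further down:

// Sketch: collect the op's canonicalization patterns and apply them.
RewritePatternSet patterns(ctx);
acc::KernelEnvironmentOp::getCanonicalizationPatterns(patterns, ctx);
(void)applyPatternsGreedily(moduleOp, std::move(patterns));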
+
//===----------------------------------------------------------------------===//
// Recipe Region Helpers
//===----------------------------------------------------------------------===//
@@ -1263,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
return recipe;
}
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName,
+ FirstprivateRecipeOp firstprivRecipe) {
+ // Create the private.recipe op with the same type as the firstprivate.recipe.
+ OpBuilder::InsertionGuard guard(builder);
+ auto varType = firstprivRecipe.getType();
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Clone the init region
+ IRMapping mapping;
+ firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
+
+ // Clone destroy region if the firstprivate.recipe has one.
+ if (!firstprivRecipe.getDestroyRegion().empty()) {
+ IRMapping mapping;
+ firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
+ mapping);
+ }
+ return recipe;
+}
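A hedged usage sketch of the new overload; `builder`, `loc`, `firstprivRecipe`, `privateOp`, and `loopOp` are assumed to exist in the caller, and the recipe name is made up:

// Sketch: derive a private recipe from an already-populated firstprivate
// recipe, then attach it to a construct via the new addPrivatization API.
std::optional<acc::PrivateRecipeOp> priv =
    acc::PrivateRecipeOp::createAndPopulate(
        builder, loc, "privatization_from_firstprivate", firstprivRecipe);
if (priv)
  loopOp.addPrivatization(builder.getContext(), privateOp, *priv);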
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1373,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
}
//===----------------------------------------------------------------------===//
-// Custom parser and printer verifier for private clause
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseSymOperandList(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
- llvm::SmallVector<SymbolRefAttr> attributes;
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseArrow() ||
- parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- return success();
- })))
- return failure();
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
- return success();
-}
-
-static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands,
- mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> attributes) {
- llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
- p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
- << std::get<1>(it).getType();
- });
-}
-
-//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
@@ -1425,45 +1600,19 @@ static LogicalResult checkDataOperands(Op op,
return success();
}
-template <typename Op>
-static LogicalResult
-checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
- mlir::OperandRange operands, llvm::StringRef operandName,
- llvm::StringRef symbolName, bool checkOperandType = true) {
- if (!operands.empty()) {
- if (!attributes || attributes->size() != operands.size())
- return op->emitOpError()
- << "expected as many " << symbolName << " symbol reference as "
- << operandName << " operands";
- } else {
- if (attributes)
- return op->emitOpError()
- << "unexpected " << symbolName << " symbol reference";
- return success();
- }
-
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
+ const mlir::ValueRange &operands,
+ llvm::StringRef operandName) {
llvm::DenseSet<Value> set;
- for (auto args : llvm::zip(operands, *attributes)) {
- mlir::Value operand = std::get<0>(args);
-
+ for (mlir::Value operand : operands) {
+ if (!mlir::isa<OpT>(operand.getDefiningOp()))
+ return accConstructOp->emitOpError()
+ << "expected " << operandName << " as defining op";
if (!set.insert(operand).second)
- return op->emitOpError()
+ return accConstructOp->emitOpError()
<< operandName << " operand appears more than once";
-
- mlir::Type varType = operand.getType();
- auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
- auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
- if (!decl)
- return op->emitOpError()
- << "expected symbol reference " << symbolRef << " to point to a "
- << operandName << " declaration";
-
- if (checkOperandType && decl.getType() && decl.getType() != varType)
- return op->emitOpError() << "expected " << operandName << " (" << varType
- << ") to be the same type as " << operandName
- << " declaration (" << decl.getType() << ")";
}
-
return success();
}
@@ -1520,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
}
LogicalResult acc::ParallelOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -1661,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
mlir::ValueRange gangPrivateOperands,
mlir::ValueRange gangFirstPrivateOperands,
mlir::ValueRange dataClauseOperands) {
-
ParallelOp::build(
odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
@@ -1670,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
/*numGangsDeviceType=*/nullptr, numWorkers,
/*numWorkersDeviceType=*/nullptr, vectorLength,
/*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
- /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
- gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
- /*firstprivatizations=*/nullptr, dataClauseOperands,
+ /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
+ gangFirstPrivateOperands, dataClauseOperands,
/*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
@@ -1749,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands(
void acc::ParallelOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
static ParseResult parseNumGangs(
@@ -2356,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
}
LogicalResult acc::SerialOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -2430,46 +2553,22 @@ void acc::SerialOp::addWaitOperands(
void acc::SerialOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -2599,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() {
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
+void acc::KernelsOp::addPrivatization(MLIRContext *context,
+ mlir::acc::PrivateOp op,
+ mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
+ getPrivateOperandsMutable().append(op.getResult());
+}
+
+void acc::KernelsOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
+ getFirstprivateOperandsMutable().append(op.getResult());
+}
+
+void acc::KernelsOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
+ getReductionOperandsMutable().append(op.getResult());
+}
+
void acc::KernelsOp::addNumWorkersOperand(
MLIRContext *context, mlir::Value newValue,
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
@@ -2691,6 +2811,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
+// KernelEnvironmentOp
+//===----------------------------------------------------------------------===//
+
+void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<RemoveEmptyKernelEnvironment>(context);
+}
+
+//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//
@@ -2899,19 +3028,21 @@ bool hasDuplicateDeviceTypes(
}
/// Check for duplicates in the DeviceType array attribute.
-LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
+/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
+static std::optional<mlir::acc::DeviceType>
+checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
if (!deviceTypes)
- return success();
+ return std::nullopt;
for (auto attr : deviceTypes) {
auto deviceTypeAttr =
mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
if (!deviceTypeAttr)
- return failure();
+ return mlir::acc::DeviceType::None;
if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
- return failure();
+ return deviceTypeAttr.getValue();
}
- return success();
+ return std::nullopt;
}
LogicalResult acc::LoopOp::verify() {
@@ -2938,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() {
getCollapseDeviceTypeAttr().getValue().size())
return emitOpError() << "collapse attribute count must match collapse"
<< " device_type count";
- if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
- return emitOpError()
- << "duplicate device_type found in collapseDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in collapseDeviceType attribute";
// Check gang
if (!getGangOperands().empty()) {
@@ -2953,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() {
return emitOpError() << "gangOperandsArgType attribute count must match"
<< " gangOperands count";
}
- if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
- return emitOpError() << "duplicate device_type found in gang attribute";
+ if (getGangAttr()) {
+ if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in gang attribute";
+ }
if (failed(verifyDeviceTypeAndSegmentCountMatch(
*this, getGangOperands(), getGangOperandsSegmentsAttr(),
@@ -2962,22 +3098,30 @@ LogicalResult acc::LoopOp::verify() {
return failure();
// Check worker
- if (failed(checkDeviceTypes(getWorkerAttr())))
- return emitOpError() << "duplicate device_type found in worker attribute";
- if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "workerNumOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in worker attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in workerNumOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
getWorkerNumOperandsDeviceTypeAttr(),
"worker")))
return failure();
// Check vector
- if (failed(checkDeviceTypes(getVectorAttr())))
- return emitOpError() << "duplicate device_type found in vector attribute";
- if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "vectorOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vector attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vectorOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
getVectorOperandsDeviceTypeAttr(),
"vector")))
@@ -3042,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() {
}
}
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (getCombined().has_value() &&
@@ -3068,8 +3212,12 @@ LogicalResult acc::LoopOp::verify() {
if (getRegion().empty())
return emitError("expected non-empty body.");
- // When it is container-like - it is expected to hold a loop-like operation.
- if (isContainerLike()) {
+ if (getUnstructured()) {
+ if (!isContainerLike())
+ return emitError(
+ "unstructured acc.loop must not have induction variables");
+ } else if (isContainerLike()) {
+ // When it is container-like - it is expected to hold a loop-like operation.
// Obtain the maximum collapse count - we use this to check that there
// are enough loops contained.
uint64_t collapseCount = getCollapseValue().value_or(1);
@@ -3484,45 +3632,21 @@ void acc::LoopOp::addGangOperands(
void acc::LoopOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -3987,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() {
if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
- "be present at the same time";
+ "be present at the same time for device_type `"
+ << acc::stringifyDeviceType(dtype) << "`";
}
return success();
@@ -4284,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
return std::nullopt;
}
+void RoutineOp::addSeq(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addVector(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addWorker(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addGang(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addGang(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+ uint64_t val) {
+ llvm::SmallVector<mlir::Attribute> dimValues;
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
+
+ if (getGangDimAttr())
+ llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
+ if (getGangDimDeviceTypeAttr())
+ llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
+
+ assert(dimValues.size() == deviceTypes.size());
+
+ if (effectiveDeviceTypes.empty()) {
+ dimValues.push_back(
+ mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
+ deviceTypes.push_back(
+ acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
+ } else {
+ for (DeviceType dt : effectiveDeviceTypes) {
+ dimValues.push_back(
+ mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
+ deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
+ }
+ }
+ assert(dimValues.size() == deviceTypes.size());
+
+ setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
+ setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
+}
+
+void RoutineOp::addBindStrName(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+ mlir::StringAttr val) {
+ unsigned before = getBindStrNameDeviceTypeAttr()
+ ? getBindStrNameDeviceTypeAttr().size()
+ : 0;
+
+ setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+ context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
+ unsigned after = getBindStrNameDeviceTypeAttr().size();
+
+ llvm::SmallVector<mlir::Attribute> vals;
+ if (getBindStrNameAttr())
+ llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
+ for (unsigned i = 0; i < after - before; ++i)
+ vals.push_back(val);
+
+ setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
+}
+
+void RoutineOp::addBindIDName(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+ mlir::SymbolRefAttr val) {
+ unsigned before =
+ getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
+
+ setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+ context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
+ unsigned after = getBindIdNameDeviceTypeAttr().size();
+
+ llvm::SmallVector<mlir::Attribute> vals;
+ if (getBindIdNameAttr())
+ llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
+ for (unsigned i = 0; i < after - before; ++i)
+ vals.push_back(val);
+
+ setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
+}
+
//===----------------------------------------------------------------------===//
// InitOp
//===----------------------------------------------------------------------===//
@@ -4667,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
.Default([&](mlir::Operation *) { return nullptr; })};
return dataOperands;
}
+
+mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
+ auto recipe{
+ llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp)
+ .Case<ACC_DATA_ENTRY_OPS>(
+ [&](auto entry) { return entry.getRecipeAttr(); })
+ .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
+ return recipe;
+}
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
new file mode 100644
index 0000000..67cdf10
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -0,0 +1,781 @@
+//===- ACCImplicitData.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the OpenACC specification for "Variables with
+// Implicitly Determined Data Attributes" (OpenACC 3.4 spec, section 2.6.2).
+//
+// Overview:
+// ---------
+// The pass automatically generates data clause operations for variables used
+// within OpenACC compute constructs (parallel, kernels, serial) that do not
+// already have explicit data clauses. The semantics follow these rules:
+//
+// 1. If there is a default(none) clause visible, no implicit data actions
+// apply.
+//
+// 2. An aggregate variable (arrays, derived types, etc.) will be treated as:
+// - In a present clause when default(present) is visible.
+// - In a copy clause otherwise.
+//
+// 3. A scalar variable will be treated as if it appears in:
+// - A copy clause if the compute construct is a kernels construct.
+// - A firstprivate clause otherwise (parallel, serial).
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Type Interface Implementation: Variables from the dialect being used
+// must implement one or both of the following MLIR interfaces:
+// `acc::MappableType` and/or `acc::PointerLikeType`
+//
+// These interfaces provide the necessary methods for the pass to:
+// - Determine variable type categories (scalar vs. aggregate)
+// - Generate appropriate bounds information
+// - Generate privatization recipes
+//
+// 2. Operation Interface Implementation: Operations that access partial
+// entities or create views should implement the following MLIR
+// interfaces: `acc::PartialEntityAccess` and/or
+// `mlir::ViewLikeOpInterface`
+//
+// These interfaces are used for proper data clause ordering, ensuring
+// that base entities are mapped before derived entities (e.g., a
+// struct is mapped before its fields, an array is mapped before
+// subarray views).
+//
+// 3. Analysis Registration (Optional): If custom behavior is needed for
+// variable name extraction or alias analysis, the dialect should
+// pre-register the `acc::OpenACCSupport` and `mlir::AliasAnalysis` analyses.
+//
+// If not registered, default behavior will be used.
+//
+// Implementation Details:
+// -----------------------
+// The pass performs the following operations:
+//
+// 1. Finds candidate variables which are live-in to the compute region and
+// are not already in a data clause or private clause.
+//
+// 2. Generates both data "entry" and "exit" clause operations that match
+// the intended action depending on variable type:
+// - copy -> acc.copyin (entry) + acc.copyout (exit)
+// - present -> acc.present (entry) + acc.delete (exit)
+// - firstprivate -> acc.firstprivate (entry only, no exit)
+//
+// 3. Ensures that default clause is taken into consideration by looking
+// through current construct and parent constructs to find the "visible
+// default clause".
+//
+// 4. Fixes up SSA value links so that uses in the acc region reference the
+// result of the newly created data clause operations.
+//
+// 5. When generating implicit data clause operations, it also adds variable
+// name information and marks them with the implicit flag.
+//
+// 6. Recipes are generated by calling the appropriate entrypoints in the
+// MappableType and PointerLikeType interfaces.
+//
+// 7. AliasAnalysis is used to determine if a variable is already covered by
+// an existing data clause (e.g., an interior pointer covered by its parent).
+//
+// Examples:
+// ---------
+//
+// Example 1: Scalar in parallel construct (implicit firstprivate)
+//
+// Before:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32>
+// acc.parallel {
+// %val = memref.load %scalar[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "x"} : memref<f32>
+// %firstpriv = acc.firstprivate varPtr(%scalar : memref<f32>)
+// -> memref<f32> {implicit = true, name = "x"}
+// acc.parallel firstprivate(@recipe -> %firstpriv : memref<f32>) {
+// %val = memref.load %firstpriv[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 2: Scalar in kernels construct (implicit copy)
+//
+// Before:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32>
+// acc.kernels {
+// %val = memref.load %scalar[] : memref<i32>
+// acc.terminator
+// }
+// }
+//
+// After:
+// func.func @test() {
+// %scalar = memref.alloca() {acc.var_name = "n"} : memref<i32>
+// %copyin = acc.copyin varPtr(%scalar : memref<i32>) -> memref<i32>
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "n"}
+// acc.kernels dataOperands(%copyin : memref<i32>) {
+// %val = memref.load %copyin[] : memref<i32>
+// acc.terminator
+// }
+// acc.copyout accPtr(%copyin : memref<i32>)
+// to varPtr(%scalar : memref<i32>)
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "n"}
+// }
+//
+// Example 3: Array (aggregate) in parallel (implicit copy)
+//
+// Before:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// acc.parallel {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %array[%c0] : memref<100xf32>
+// acc.yield
+// }
+// }
+//
+// After:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// %copyin = acc.copyin varPtr(%array : memref<100xf32>)
+// -> memref<100xf32>
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "arr"}
+// acc.parallel dataOperands(%copyin : memref<100xf32>) {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %copyin[%c0] : memref<100xf32>
+// acc.yield
+// }
+// acc.copyout accPtr(%copyin : memref<100xf32>)
+// to varPtr(%array : memref<100xf32>)
+// {dataClause = #acc<data_clause acc_copy>,
+// implicit = true, name = "arr"}
+// }
+//
+// Example 4: Array with default(present)
+//
+// Before:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// acc.parallel {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %array[%c0] : memref<100xf32>
+// acc.yield
+// } attributes {defaultAttr = #acc<defaultvalue present>}
+// }
+//
+// After:
+// func.func @test() {
+// %array = memref.alloca() {acc.var_name = "arr"} : memref<100xf32>
+// %present = acc.present varPtr(%array : memref<100xf32>)
+// -> memref<100xf32>
+// {implicit = true, name = "arr"}
+// acc.parallel dataOperands(%present : memref<100xf32>)
+// attributes {defaultAttr = #acc<defaultvalue present>} {
+// %c0 = arith.constant 0 : index
+// %val = memref.load %present[%c0] : memref<100xf32>
+// acc.yield
+// }
+// acc.delete accPtr(%present : memref<100xf32>)
+// {dataClause = #acc<data_clause acc_present>,
+// implicit = true, name = "arr"}
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <type_traits>
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITDATA
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-implicit-data"
+
+using namespace mlir;
+
+namespace {
+
+class ACCImplicitData : public acc::impl::ACCImplicitDataBase<ACCImplicitData> {
+public:
+ using acc::impl::ACCImplicitDataBase<ACCImplicitData>::ACCImplicitDataBase;
+
+ void runOnOperation() override;
+
+private:
+ /// Looks through the `dominatingDataClauses` to find the original data clause
+ /// op for an alias. Returns nullptr if no original data clause op is found.
+ template <typename OpT>
+ Operation *getOriginalDataClauseOpForAlias(
+ Value var, OpBuilder &builder, OpT computeConstructOp,
+ const SmallVector<Value> &dominatingDataClauses);
+
+ /// Generates the appropriate `acc.copyin`, `acc.present`,`acc.firstprivate`,
+ /// etc. data clause op for a candidate variable.
+ template <typename OpT>
+ Operation *generateDataClauseOpForCandidate(
+ Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp,
+ const SmallVector<Value> &dominatingDataClauses,
+ const std::optional<acc::ClauseDefaultValue> &defaultClause);
+
+ /// Generates the implicit data ops for a compute construct.
+ template <typename OpT>
+ void generateImplicitDataOps(
+ ModuleOp &module, OpT computeConstructOp,
+ std::optional<acc::ClauseDefaultValue> &defaultClause);
+
+ /// Generates a private recipe for a variable.
+ acc::PrivateRecipeOp generatePrivateRecipe(ModuleOp &module, Value var,
+ Location loc, OpBuilder &builder,
+ acc::OpenACCSupport &accSupport);
+
+ /// Generates a firstprivate recipe for a variable.
+ acc::FirstprivateRecipeOp
+ generateFirstprivateRecipe(ModuleOp &module, Value var, Location loc,
+ OpBuilder &builder,
+ acc::OpenACCSupport &accSupport);
+
+ /// Generates recipes for a list of variables.
+ void generateRecipes(ModuleOp &module, OpBuilder &builder,
+ Operation *computeConstructOp,
+ const SmallVector<Value> &newOperands);
+};
+
+/// Determines if a variable is a candidate for implicit data mapping.
+/// Returns true if the variable is a candidate, false otherwise.
+static bool isCandidateForImplicitData(Value val, Region &accRegion) {
+ // Ensure the variable is an allowed type for data clause.
+ if (!acc::isPointerLikeType(val.getType()) &&
+ !acc::isMappableType(val.getType()))
+ return false;
+
+ // If this is already coming from a data clause, we do not need to generate
+ // another.
+ if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp()))
+ return false;
+
+ // If this is only used by private clauses, it is not a real live-in.
+ if (acc::isOnlyUsedByPrivateClauses(val, accRegion))
+ return false;
+
+ return true;
+}
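A sketch of how this filter might be driven from the values live into the compute region; the pass's actual traversal may differ, but `getUsedValuesDefinedAbove` comes from the RegionUtils header already included above:

// Sketch: gather values defined above the acc region and keep only those
// that still need an implicit data clause.
llvm::SetVector<Value> liveIns;
mlir::getUsedValuesDefinedAbove(accRegion, liveIns);
SmallVector<Value> candidates;
for (Value v : liveIns)
  if (isCandidateForImplicitData(v, accRegion))
    candidates.push_back(v);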
+
+template <typename OpT>
+Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
+ Value var, OpBuilder &builder, OpT computeConstructOp,
+ const SmallVector<Value> &dominatingDataClauses) {
+ auto &aliasAnalysis = this->getAnalysis<AliasAnalysis>();
+ for (auto dataClause : dominatingDataClauses) {
+ if (auto *dataClauseOp = dataClause.getDefiningOp()) {
+ // Only accept clauses that guarantee that the alias is present.
+ if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
+ acc::DevicePtrOp>(dataClauseOp))
+ if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust())
+ return dataClauseOp;
+ }
+ }
+ return nullptr;
+}
+
+// Generates bounds for variables that have unknown dimensions
+static void fillInBoundsForUnknownDimensions(Operation *dataClauseOp,
+ OpBuilder &builder) {
+
+ if (!acc::getBounds(dataClauseOp).empty())
+ // If bounds are already present, do not overwrite them.
+ return;
+
+ // For types that have unknown dimensions, attempt to generate bounds by
+ // relying on MappableType being able to extract it from the IR.
+ auto var = acc::getVar(dataClauseOp);
+ auto type = var.getType();
+ if (auto mappableTy = dyn_cast<acc::MappableType>(type)) {
+ if (mappableTy.hasUnknownDimensions()) {
+ TypeSwitch<Operation *>(dataClauseOp)
+ .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClauseOp) {
+ if (std::is_same_v<decltype(dataClauseOp), acc::DevicePtrOp>)
+ return;
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(dataClauseOp);
+ auto bounds = mappableTy.generateAccBounds(var, builder);
+ if (!bounds.empty())
+ dataClauseOp.getBoundsMutable().assign(bounds);
+ });
+ }
+ }
+}
+
+acc::PrivateRecipeOp
+ACCImplicitData::generatePrivateRecipe(ModuleOp &module, Value var,
+ Location loc, OpBuilder &builder,
+ acc::OpenACCSupport &accSupport) {
+ auto type = var.getType();
+ std::string recipeName =
+ accSupport.getRecipeName(acc::RecipeKind::private_recipe, type, var);
+
+ // Check if recipe already exists
+ auto existingRecipe = module.lookupSymbol<acc::PrivateRecipeOp>(recipeName);
+ if (existingRecipe)
+ return existingRecipe;
+
+ // Set insertion point to module body in a scoped way
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(module.getBody());
+
+ auto recipe =
+ acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName, type);
+ if (!recipe.has_value())
+ return accSupport.emitNYI(loc, "implicit private"), nullptr;
+ return recipe.value();
+}
+
+acc::FirstprivateRecipeOp
+ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var,
+ Location loc, OpBuilder &builder,
+ acc::OpenACCSupport &accSupport) {
+ auto type = var.getType();
+ std::string recipeName =
+ accSupport.getRecipeName(acc::RecipeKind::firstprivate_recipe, type, var);
+
+ // Check if recipe already exists
+ auto existingRecipe =
+ module.lookupSymbol<acc::FirstprivateRecipeOp>(recipeName);
+ if (existingRecipe)
+ return existingRecipe;
+
+ // Set insertion point to module body in a scoped way
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(module.getBody());
+
+ auto recipe = acc::FirstprivateRecipeOp::createAndPopulate(builder, loc,
+ recipeName, type);
+ if (!recipe.has_value())
+ return accSupport.emitNYI(loc, "implicit firstprivate"), nullptr;
+ return recipe.value();
+}
+
+void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder,
+ Operation *computeConstructOp,
+ const SmallVector<Value> &newOperands) {
+ auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
+ for (auto var : newOperands) {
+ auto loc{var.getLoc()};
+ if (auto privateOp = dyn_cast<acc::PrivateOp>(var.getDefiningOp())) {
+ auto recipe = generatePrivateRecipe(
+ module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport);
+ if (recipe)
+ privateOp.setRecipeAttr(
+ SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
+ } else if (auto firstprivateOp =
+ dyn_cast<acc::FirstprivateOp>(var.getDefiningOp())) {
+ auto recipe = generateFirstprivateRecipe(
+ module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport);
+ if (recipe)
+ firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
+ module->getContext(), recipe.getSymName().str()));
+ } else {
+ accSupport.emitNYI(var.getLoc(), "implicit reduction");
+ }
+ }
+}
+
+// Generates the data entry data clause op so that it adheres to the OpenACC
+// rules below (line numbers and text from the OpenACC 3.4 specification):
+// 1388 An aggregate variable will be treated as if it appears either:
+// 1389 - In a present clause if there is a default(present) clause visible at
+// the compute construct.
+// 1391 - In a copy clause otherwise.
+// 1392 A scalar variable will be treated as if it appears either:
+// 1393 - In a copy clause if the compute construct is a kernels construct.
+// 1394 - In a firstprivate clause otherwise.
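+//
+// As an illustrative sketch only (hypothetical IR, printed op syntax
+// abbreviated and recipe references elided), a live-in scalar %s and an
+// aggregate %a on an acc.parallel with no default clause would decompose as:
+//
+//   %fp = acc.firstprivate varPtr(%s : memref<f32>) -> memref<f32>
+//   %cp = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32>
+//   acc.parallel firstprivate(%fp : memref<f32>)
+//       dataOperands(%cp : memref<10xf32>) { ... }
+//
+// where the copyin op records the acc_copy data clause so that the matching
+// exit operation can be generated after the region.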
+template <typename OpT>
+Operation *ACCImplicitData::generateDataClauseOpForCandidate(
+ Value var, ModuleOp &module, OpBuilder &builder, OpT computeConstructOp,
+ const SmallVector<Value> &dominatingDataClauses,
+ const std::optional<acc::ClauseDefaultValue> &defaultClause) {
+ auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
+ acc::VariableTypeCategory typeCategory =
+ acc::VariableTypeCategory::uncategorized;
+ if (auto mappableTy = dyn_cast<acc::MappableType>(var.getType())) {
+ typeCategory = mappableTy.getTypeCategory(var);
+ } else if (auto pointerLikeTy =
+ dyn_cast<acc::PointerLikeType>(var.getType())) {
+ typeCategory = pointerLikeTy.getPointeeTypeCategory(
+ cast<TypedValue<acc::PointerLikeType>>(var),
+ pointerLikeTy.getElementType());
+ }
+
+ bool isScalar =
+ acc::bitEnumContainsAny(typeCategory, acc::VariableTypeCategory::scalar);
+ bool isAnyAggregate = acc::bitEnumContainsAny(
+ typeCategory, acc::VariableTypeCategory::aggregate);
+ Location loc = computeConstructOp->getLoc();
+
+  Operation *op = getOriginalDataClauseOpForAlias(
+      var, builder, computeConstructOp, dominatingDataClauses);
+ if (op) {
+ if (isa<acc::NoCreateOp>(op))
+ return acc::NoCreateOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var),
+ acc::getBounds(op));
+
+ if (isa<acc::DevicePtrOp>(op))
+ return acc::DevicePtrOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var),
+ acc::getBounds(op));
+
+ // The original data clause op is a PresentOp, CopyinOp, or CreateOp,
+ // hence guaranteed to be present.
+ return acc::PresentOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var),
+ acc::getBounds(op));
+ } else if (isScalar) {
+ if (enableImplicitReductionCopy &&
+ acc::isOnlyUsedByReductionClauses(var,
+ computeConstructOp->getRegion(0))) {
+ auto copyinOp =
+ acc::CopyinOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ copyinOp.setDataClause(acc::DataClause::acc_reduction);
+ return copyinOp.getOperation();
+ }
+ if constexpr (std::is_same_v<OpT, acc::KernelsOp> ||
+ std::is_same_v<OpT, acc::KernelEnvironmentOp>) {
+ // Scalars are implicit copyin in kernels construct.
+      // We also do the same for acc.kernel_environment because the semantics
+      // of user variable mappings should be applied while the ACC construct
+      // exists, and at this point we should only be dealing with unmapped
+      // variables that were made live-in by the compiler.
+ // TODO: This may be revisited.
+ auto copyinOp =
+ acc::CopyinOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ copyinOp.setDataClause(acc::DataClause::acc_copy);
+ return copyinOp.getOperation();
+ } else {
+      // Scalars are implicit firstprivate in parallel and serial constructs.
+ return acc::FirstprivateOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ }
+ } else if (isAnyAggregate) {
+ Operation *newDataOp = nullptr;
+
+    // When a default(present) clause is visible, the implicit behavior is
+    // present.
+ if (defaultClause.has_value() &&
+ defaultClause.value() == acc::ClauseDefaultValue::Present) {
+ newDataOp = acc::PresentOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ newDataOp->setAttr(acc::getFromDefaultClauseAttrName(),
+ builder.getUnitAttr());
+ } else {
+ auto copyinOp =
+ acc::CopyinOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ copyinOp.setDataClause(acc::DataClause::acc_copy);
+ newDataOp = copyinOp.getOperation();
+ }
+
+ return newDataOp;
+ } else {
+    // This is not a fatal error - for example, when the element type is a
+    // pointer type (i.e., we have a pointer to a pointer), it is potentially a
+    // deep-copy scenario which is not handled here.
+    // Other types need to be canonicalized. Thus, just log unhandled cases.
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unhandled case for implicit data mapping " << var << "\n");
+ }
+ return nullptr;
+}
+
+// Ensures that the result values from the acc data clause ops are used inside
+// the acc region, i.e.:
+// acc.kernels {
+// use %val
+// }
+// =>
+// %dev = acc.dataop %val
+// acc.kernels {
+// use %dev
+// }
+static void legalizeValuesInRegion(Region &accRegion,
+ SmallVector<Value> &newPrivateOperands,
+ SmallVector<Value> &newDataClauseOperands) {
+ for (Value dataClause :
+ llvm::concat<Value>(newDataClauseOperands, newPrivateOperands)) {
+ Value var = acc::getVar(dataClause.getDefiningOp());
+ replaceAllUsesInRegionWith(var, dataClause, accRegion);
+ }
+}
+
+// Adds the private operands to the compute construct operation.
+template <typename OpT>
+static void addNewPrivateOperands(OpT &accOp,
+ const SmallVector<Value> &privateOperands) {
+ if (privateOperands.empty())
+ return;
+
+ for (auto priv : privateOperands) {
+ if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
+ accOp.getPrivateOperandsMutable().append(priv);
+ } else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
+ accOp.getFirstprivateOperandsMutable().append(priv);
+ } else {
+ llvm_unreachable("unhandled reduction operand");
+ }
+ }
+}
+
+static Operation *findDataExitOp(Operation *dataEntryOp) {
+ auto res = acc::getAccVar(dataEntryOp);
+ for (auto *user : res.getUsers())
+ if (isa<ACC_DATA_EXIT_OPS>(user))
+ return user;
+ return nullptr;
+}
+
+// Generates the matching data exit operations following the acc dialect's
+// data clause decomposition:
+// https://mlir.llvm.org/docs/Dialects/OpenACCDialect/#operation-categories
+// Key ones used here:
+// * acc {construct} copy -> acc.copyin (before region) + acc.copyout (after
+// region)
+// * acc {construct} present -> acc.present (before region) + acc.delete
+// (after region)
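+//
+// A minimal sketch (hypothetical IR, printed op syntax abbreviated) of the
+// expected pairing for an implicitly copied %a:
+//
+//   %cp = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32>
+//   acc.parallel dataOperands(%cp : memref<10xf32>) { ... }
+//   acc.copyout accPtr(%cp : memref<10xf32>) to varPtr(%a : memref<10xf32>)
+//
+// For an implicit present or no_create entry, an acc.delete is emitted after
+// the region instead; acc.deviceptr entries get no exit operation.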
+static void
+generateDataExitOperations(OpBuilder &builder, Operation *accOp,
+ const SmallVector<Value> &newDataClauseOperands,
+ const SmallVector<Value> &sortedDataClauseOperands) {
+ builder.setInsertionPointAfter(accOp);
+ Value lastDataClause = nullptr;
+ for (auto dataEntry : llvm::reverse(sortedDataClauseOperands)) {
+ if (llvm::find(newDataClauseOperands, dataEntry) ==
+ newDataClauseOperands.end()) {
+ // If this is not a new data clause operand, we should not generate an
+ // exit operation for it.
+ lastDataClause = dataEntry;
+ continue;
+ }
+ if (lastDataClause)
+ if (auto *dataExitOp = findDataExitOp(lastDataClause.getDefiningOp()))
+ builder.setInsertionPointAfter(dataExitOp);
+ Operation *dataEntryOp = dataEntry.getDefiningOp();
+ if (isa<acc::CopyinOp>(dataEntryOp)) {
+ auto copyoutOp = acc::CopyoutOp::create(
+ builder, dataEntryOp->getLoc(), dataEntry, acc::getVar(dataEntryOp),
+ /*structured=*/true, /*implicit=*/true,
+ acc::getVarName(dataEntryOp).value(), acc::getBounds(dataEntryOp));
+ copyoutOp.setDataClause(acc::DataClause::acc_copy);
+ } else if (isa<acc::PresentOp, acc::NoCreateOp>(dataEntryOp)) {
+ auto deleteOp = acc::DeleteOp::create(
+ builder, dataEntryOp->getLoc(), dataEntry,
+ /*structured=*/true, /*implicit=*/true,
+ acc::getVarName(dataEntryOp).value(), acc::getBounds(dataEntryOp));
+ deleteOp.setDataClause(acc::getDataClause(dataEntryOp).value());
+ } else if (isa<acc::DevicePtrOp>(dataEntryOp)) {
+ // Do nothing.
+ } else {
+ llvm_unreachable("unhandled data exit");
+ }
+ lastDataClause = dataEntry;
+ }
+}
+
+/// Returns all base references of a value in order.
+/// So for example, if we have a reference to a struct field like
+/// s.f1.f2.f3, this will return <s, s.f1, s.f1.f2, s.f1.f2.f3>.
+/// Any intermediate casts/view-like operations are included in the
+/// chain as well.
+static SmallVector<Value> getBaseRefsChain(Value val) {
+ SmallVector<Value> baseRefs;
+ baseRefs.push_back(val);
+ while (true) {
+ Value prevVal = val;
+
+ val = acc::getBaseEntity(val);
+ if (val != baseRefs.front())
+ baseRefs.insert(baseRefs.begin(), val);
+
+ // If this is a view-like operation, it is effectively another
+ // view of the same entity so we should add it to the chain also.
+ if (auto viewLikeOp = val.getDefiningOp<ViewLikeOpInterface>()) {
+ val = viewLikeOp.getViewSource();
+ baseRefs.insert(baseRefs.begin(), val);
+ }
+
+    // Stop once no further progress is made.
+ if (val == prevVal)
+ break;
+ }
+
+ return baseRefs;
+}
+
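+// Inserts `newClause` into `sortedDataClauseOperands` so that a clause for a
+// base entity appears before clauses for its members. A small worked example
+// (hypothetical values): if the list already holds a clause for a member
+// %s_f1 whose base-ref chain contains %s, a newly created clause for the base
+// entity %s is moved (in the IR and in the list) right before the member's
+// clause, yielding <clause(%s), clause(%s_f1), ...>; otherwise the new clause
+// is appended at the end.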
+static void insertInSortedOrder(SmallVector<Value> &sortedDataClauseOperands,
+ Operation *newClause) {
+ auto *insertPos =
+ std::find_if(sortedDataClauseOperands.begin(),
+ sortedDataClauseOperands.end(), [&](Value dataClauseVal) {
+ // Get the base refs for the current clause we are looking
+ // at.
+ auto var = acc::getVar(dataClauseVal.getDefiningOp());
+ auto baseRefs = getBaseRefsChain(var);
+
+                           // If the newClause's var is a base ref of an
+                           // existing clause's var, insert it right before
+                           // that clause. Return true to stop the iteration
+                           // in that case.
+ return std::find(baseRefs.begin(), baseRefs.end(),
+ acc::getVar(newClause)) != baseRefs.end();
+ });
+
+ if (insertPos != sortedDataClauseOperands.end()) {
+ newClause->moveBefore(insertPos->getDefiningOp());
+ sortedDataClauseOperands.insert(insertPos, acc::getAccVar(newClause));
+ } else {
+ sortedDataClauseOperands.push_back(acc::getAccVar(newClause));
+ }
+}
+
+template <typename OpT>
+void ACCImplicitData::generateImplicitDataOps(
+ ModuleOp &module, OpT computeConstructOp,
+ std::optional<acc::ClauseDefaultValue> &defaultClause) {
+ // Implicit data attributes are only applied if "[t]here is no default(none)
+ // clause visible at the compute construct."
+ if (defaultClause.has_value() &&
+ defaultClause.value() == acc::ClauseDefaultValue::None)
+ return;
+ assert(!defaultClause.has_value() ||
+ defaultClause.value() == acc::ClauseDefaultValue::Present);
+
+ // 1) Collect live-in values.
+ Region &accRegion = computeConstructOp->getRegion(0);
+ SetVector<Value> liveInValues;
+ getUsedValuesDefinedAbove(accRegion, liveInValues);
+
+  // 2) Run the filtering to find relevant pointers that need to be copied.
+ auto isCandidate{[&](Value val) -> bool {
+ return isCandidateForImplicitData(val, accRegion);
+ }};
+ auto candidateVars(
+ llvm::to_vector(llvm::make_filter_range(liveInValues, isCandidate)));
+ if (candidateVars.empty())
+ return;
+
+ // 3) Generate data clauses for the variables.
+ SmallVector<Value> newPrivateOperands;
+ SmallVector<Value> newDataClauseOperands;
+ OpBuilder builder(computeConstructOp);
+  LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n"
+                          << computeConstructOp << "\n");
+ auto &domInfo = this->getAnalysis<DominanceInfo>();
+ auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
+ auto dominatingDataClauses =
+ acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo);
+ for (auto var : candidateVars) {
+ auto newDataClauseOp = generateDataClauseOpForCandidate(
+ var, module, builder, computeConstructOp, dominatingDataClauses,
+ defaultClause);
+    if (!newDataClauseOp)
+      continue;
+    fillInBoundsForUnknownDimensions(newDataClauseOp, builder);
+    LLVM_DEBUG(llvm::dbgs() << "Generated data clause for " << var << ":\n"
+                            << "\t" << *newDataClauseOp << "\n");
+ if (isa_and_nonnull<acc::PrivateOp, acc::FirstprivateOp, acc::ReductionOp>(
+ newDataClauseOp)) {
+ newPrivateOperands.push_back(acc::getAccVar(newDataClauseOp));
+ } else if (isa_and_nonnull<ACC_DATA_CLAUSE_OPS>(newDataClauseOp)) {
+ newDataClauseOperands.push_back(acc::getAccVar(newDataClauseOp));
+ dominatingDataClauses.push_back(acc::getAccVar(newDataClauseOp));
+ }
+ }
+
+  // 4) Legalize values in the region (i.e., rewrite uses in the region to use
+  // the results of the data clause ops).
+ legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
+
+ // 5) Generate private recipes which are required for properly attaching
+ // private operands.
+ if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
+ !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
+ generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
+
+ // 6) Figure out insertion order for the new data clause operands.
+ SmallVector<Value> sortedDataClauseOperands(
+ computeConstructOp.getDataClauseOperands());
+ for (auto newClause : newDataClauseOperands)
+ insertInSortedOrder(sortedDataClauseOperands, newClause.getDefiningOp());
+
+ // 7) Generate the data exit operations.
+ generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
+ sortedDataClauseOperands);
+ // 8) Add all of the new operands to the compute construct op.
+ if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
+ !std::is_same_v<OpT, acc::KernelEnvironmentOp>)
+ addNewPrivateOperands(computeConstructOp, newPrivateOperands);
+ computeConstructOp.getDataClauseOperandsMutable().assign(
+ sortedDataClauseOperands);
+}
+
+void ACCImplicitData::runOnOperation() {
+ ModuleOp module = this->getOperation();
+ module.walk([&](Operation *op) {
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(op)) {
+ assert(op->getNumRegions() == 1 && "must have 1 region");
+
+ auto defaultClause = acc::getDefaultAttr(op);
+ llvm::TypeSwitch<Operation *, void>(op)
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ [&](auto op) {
+ generateImplicitDataOps(module, op, defaultClause);
+ })
+ .Default([&](Operation *) {});
+ }
+ });
+}
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
new file mode 100644
index 0000000..8cab223
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
@@ -0,0 +1,431 @@
+//===- ACCImplicitDeclare.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass applies implicit `acc declare` actions to global variables
+// referenced in OpenACC compute regions and routine functions.
+//
+// Overview:
+// ---------
+// Global references in acc regions (for globals not marked with `acc
+// declare` by the user) can be handled in one of two ways:
+// - Mapped through data clauses
+// - Implicitly marked as `acc declare` (this pass)
+//
+// The OpenACC specification itself only describes implicit data mapping rules,
+// whose implementation is captured in the `ACCImplicitData` pass.
+//
+// However, it is both advantageous and required for certain cases to
+// use implicit `acc declare` instead:
+// - Any functions that are implicitly marked as `acc routine` through
+// `ACCImplicitRoutine` may reference globals. Since data mapping
+// is only possible for compute regions, such globals can only be
+// made available on device through `acc declare`.
+// - The compiler can generate and use globals needed by the IR representation,
+//   such as type descriptors or names used for runtime calls and error
+//   reporting. Such globals are often introduced after frontend semantic
+//   checking since they are implementation details, and therefore would not
+//   have been visible for a user to mark with `acc declare`.
+// - Constant globals such as filename strings or data initialization values
+// are values that do not get mutated but are still needed for appropriate
+// runtime execution. If a kernel is launched 1000 times, it is not a
+// good idea to map such a global 1000 times. Therefore, such globals
+// benefit from being marked with `acc declare`.
+//
+// This pass automatically marks global variables with the `acc.declare`
+// attribute when they are referenced in OpenACC compute constructs or routine
+// functions and meet the criteria noted above, ensuring they are properly
+// handled for device execution.
+//
+// The pass performs two main optimizations:
+//
+// 1. Hoisting: For non-constant globals referenced in compute regions, the
+// pass hoists the address-of operation out of the region when possible,
+// allowing them to be implicitly mapped through normal data clause
+// mechanisms rather than requiring declare marking.
+//
+// 2. Declaration: For globals that must be available on the device (constants,
+// globals in routines, globals in recipe operations), the pass adds the
+// `acc.declare` attribute with the copyin data clause.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that compute addresses
+// of global variables must implement the `acc::AddressOfGlobalOpInterface`
+// and those that represent globals must implement the
+// `acc::GlobalOpInterface`. Additionally, any operations that indirectly
+// access globals must implement the `acc::IndirectGlobalAccessOpInterface`.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+// determining if a symbol use is valid within GPU regions, the dialect
+// should pre-register the `acc::OpenACCSupport` analysis.
+//
+// Examples:
+// ---------
+//
+// Example 1: Non-constant global in compute region (hoisted)
+//
+// Before:
+// memref.global @g_scalar : memref<f32> = dense<0.0>
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_scalar : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// memref.global @g_scalar : memref<f32> = dense<0.0>
+// func.func @test() {
+// %addr = memref.get_global @g_scalar : memref<f32>
+// acc.serial {
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 2: Constant global in compute region (declared)
+//
+// Before:
+// memref.global constant @g_const : memref<f32> = dense<1.0>
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_const : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// memref.global constant @g_const : memref<f32> = dense<1.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_const : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 3: Global in acc routine (declared)
+//
+// Before:
+// memref.global @g_data : memref<f32> = dense<0.0>
+// acc.routine @routine_0 func(@device_func)
+// func.func @device_func() attributes {acc.routine_info = ...} {
+// %addr = memref.get_global @g_data : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// }
+//
+// After:
+// memref.global @g_data : memref<f32> = dense<0.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// acc.routine @routine_0 func(@device_func)
+// func.func @device_func() attributes {acc.routine_info = ...} {
+// %addr = memref.get_global @g_data : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// }
+//
+// Example 4: Global in private recipe (declared if recipe is used)
+//
+// Before:
+// memref.global @g_init : memref<f32> = dense<0.0>
+// acc.private.recipe @priv_recipe : memref<f32> init {
+// ^bb0(%arg0: memref<f32>):
+// %alloc = memref.alloc() : memref<f32>
+// %global = memref.get_global @g_init : memref<f32>
+// %val = memref.load %global[] : memref<f32>
+// memref.store %val, %alloc[] : memref<f32>
+// acc.yield %alloc : memref<f32>
+// } destroy { ... }
+// func.func @test() {
+// %var = memref.alloc() : memref<f32>
+// %priv = acc.private varPtr(%var : memref<f32>)
+// recipe(@priv_recipe) -> memref<f32>
+// acc.parallel private(%priv : memref<f32>) { ... }
+// }
+//
+// After:
+// memref.global @g_init : memref<f32> = dense<0.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// acc.private.recipe @priv_recipe : memref<f32> init {
+// ^bb0(%arg0: memref<f32>):
+// %alloc = memref.alloc() : memref<f32>
+// %global = memref.get_global @g_init : memref<f32>
+// %val = memref.load %global[] : memref<f32>
+// memref.store %val, %alloc[] : memref<f32>
+// acc.yield %alloc : memref<f32>
+// } destroy { ... }
+// func.func @test() {
+// %var = memref.alloc() : memref<f32>
+// %priv = acc.private varPtr(%var : memref<f32>)
+// recipe(@priv_recipe) -> memref<f32>
+// acc.parallel private(%priv : memref<f32>) { ... }
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITDECLARE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-implicit-declare"
+
+using namespace mlir;
+
+namespace {
+
+using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>;
+
+/// Checks whether a use of the requested `globalOp` should be considered
+/// for hoisting out of the acc region in order to avoid `acc declare`ing
+/// something that should instead be implicitly mapped.
+static bool isGlobalUseCandidateForHoisting(Operation *globalOp,
+ Operation *user,
+ SymbolRefAttr symbol,
+ acc::OpenACCSupport &accSupport) {
+  // This symbol is valid in a GPU region. This means semantics would change
+  // if the use were moved to the host - therefore it is not a candidate.
+ if (accSupport.isValidSymbolUse(user, symbol))
+ return false;
+
+ bool isConstant = false;
+ bool isFunction = false;
+
+ if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp))
+ isConstant = globalVarOp.isConstant();
+
+ if (isa<FunctionOpInterface>(globalOp))
+ isFunction = true;
+
+ // Constants should be kept in device code to ensure they are duplicated.
+ // Function references should be kept in device code to ensure their device
+  // addresses are computed. Everything else should be hoisted since we already
+  // proved it is not a valid symbol use in a GPU region.
+ return !isConstant && !isFunction;
+}
+
+/// Checks whether it is valid to use acc.declare marking on the global.
+bool isValidForAccDeclare(Operation *globalOp) {
+ // For functions - we use acc.routine marking instead.
+ return !isa<FunctionOpInterface>(globalOp);
+}
+
+/// Checks whether a recipe operation has meaningful use of its symbol that
+/// justifies processing its regions for global references. Returns false if:
+/// 1. The recipe has no symbol uses at all, or
+/// 2. The only symbol use is the recipe's own symbol definition
+template <typename RecipeOpT>
+static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) {
+ std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod);
+
+ // No recipe symbol uses.
+ if (!symbolUses.has_value() || symbolUses->empty())
+ return false;
+
+ // If more than one use, assume it's used.
+ auto begin = symbolUses->begin();
+ auto end = symbolUses->end();
+ if (begin != end && std::next(begin) != end)
+ return true;
+
+ // If single use, check if the use is the recipe itself.
+ const SymbolTable::SymbolUse &use = *symbolUses->begin();
+ return use.getUser() != recipeOp.getOperation();
+}
+
+// Hoists addr_of operations for non-constant globals out of OpenACC regions.
+// This way, they are implicitly mapped instead of being considered for
+// implicit declare.
+template <typename AccConstructT>
+static void hoistNonConstantDirectUses(AccConstructT accOp,
+ acc::OpenACCSupport &accSupport) {
+ accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) {
+ SymbolRefAttr symRef = addrOfOp.getSymbol();
+ if (symRef) {
+ Operation *globalOp =
+ SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef);
+ if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
+ accSupport)) {
+ addrOfOp->moveBefore(accOp);
+ LLVM_DEBUG(
+ llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t";
+ accOp->print(llvm::dbgs(),
+ OpPrintingFlags{}.skipRegions().enableDebugInfo());
+ llvm::dbgs() << "\n");
+ }
+ }
+ });
+}
+
+// Collects the globals referenced in a device region
+static void collectGlobalsFromDeviceRegion(Region &region,
+ GlobalOpSetT &globals,
+ acc::OpenACCSupport &accSupport,
+ SymbolTable &symTab) {
+ region.walk([&](Operation *op) {
+ // 1) Only consider relevant operations which use symbols
+ auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op);
+ if (addrOfOp) {
+ SymbolRefAttr symRef = addrOfOp.getSymbol();
+      // 2) Found an operation which uses the symbol. Next, determine if it
+      // is a candidate for `acc declare`. One of the criteria considered
+      // is whether this symbol is not already a device one (either because
+      // acc declare is already used or because this is a CUF global).
+ Operation *globalOp = nullptr;
+ bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp);
+ // 3) Add the candidate to the set of globals to be `acc declare`d.
+ if (isCandidate && globalOp && isValidForAccDeclare(globalOp))
+ globals.insert(globalOp);
+ } else if (auto indirectAccessOp =
+ dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) {
+ // Process operations that indirectly access globals
+ llvm::SmallVector<SymbolRefAttr> symbols;
+ indirectAccessOp.getReferencedSymbols(symbols, &symTab);
+ for (SymbolRefAttr symRef : symbols)
+ if (Operation *globalOp = symTab.lookup(symRef.getLeafReference()))
+ if (isValidForAccDeclare(globalOp))
+ globals.insert(globalOp);
+ }
+ });
+}
+
+// Adds the declare attribute to the operation `op`.
+static void addDeclareAttr(MLIRContext *context, Operation *op,
+ acc::DataClause clause) {
+ op->setAttr(acc::getDeclareAttrName(),
+ acc::DeclareAttr::get(context,
+ acc::DataClauseAttr::get(context, clause)));
+}
+
+// This pass applies implicit declare actions for globals referenced in
+// OpenACC compute and routine regions.
+class ACCImplicitDeclare
+ : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> {
+public:
+ using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase;
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+ MLIRContext *context = &getContext();
+ acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
+
+    // 1) Start off by hoisting any AddressOf operations out of acc regions
+    // for any cases we do not want to `acc declare`. This is because we can
+    // rely on implicit data mapping in the majority of cases without
+    // needlessly polluting the set of device globals.
+ mod.walk([&](Operation *op) {
+ TypeSwitch<Operation *, void>(op)
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ [&](auto accOp) {
+ hoistNonConstantDirectUses(accOp, accSupport);
+ });
+ });
+
+ // 2) Collect global symbols which need to be `acc declare`d. Do it for
+ // compute regions, acc routine, and existing globals with the declare
+ // attribute.
+ SymbolTable symTab(mod);
+ GlobalOpSetT globalsToAccDeclare;
+ mod.walk([&](Operation *op) {
+ TypeSwitch<Operation *, void>(op)
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ [&](auto accOp) {
+ collectGlobalsFromDeviceRegion(
+ accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
+ })
+ .Case<FunctionOpInterface>([&](auto func) {
+ if ((acc::isAccRoutine(func) ||
+ acc::isSpecializedAccRoutine(func)) &&
+ !func.isExternal())
+ collectGlobalsFromDeviceRegion(func.getFunctionBody(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ })
+ .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) {
+ if (globalVarOp->getAttr(acc::getDeclareAttrName()))
+ if (Region *initRegion = globalVarOp.getInitRegion())
+ collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare,
+ accSupport, symTab);
+ })
+ .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) {
+ if (hasRelevantRecipeUse(privateRecipe, mod)) {
+ collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ }
+ })
+ .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) {
+ if (hasRelevantRecipeUse(firstprivateRecipe, mod)) {
+ collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ collectGlobalsFromDeviceRegion(
+ firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare,
+ accSupport, symTab);
+ collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ }
+ })
+ .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) {
+ if (hasRelevantRecipeUse(reductionRecipe, mod)) {
+ collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ collectGlobalsFromDeviceRegion(
+ reductionRecipe.getCombinerRegion(), globalsToAccDeclare,
+ accSupport, symTab);
+ }
+ });
+ });
+
+    // 3) Finally, generate the appropriate declare actions needed to ensure
+    // these globals are treated as device globals.
+ for (Operation *globalOp : globalsToAccDeclare) {
+ LLVM_DEBUG(
+ llvm::dbgs() << "Global is being `acc declare copyin`d: ";
+ globalOp->print(llvm::dbgs(),
+ OpPrintingFlags{}.skipRegions().enableDebugInfo());
+ llvm::dbgs() << "\n");
+
+ // Mark it as declare copyin.
+ addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin);
+
+      // TODO: May need to create the global constructor which does the mapping
+      // action. It is not yet clear whether this is needed (the globals might
+      // just end up in the GPU image without requiring mapping via the
+      // runtime).
+ }
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp
new file mode 100644
index 0000000..12efaf4
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp
@@ -0,0 +1,237 @@
+//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the implicit rules described in OpenACC specification
+// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1).
+//
+// "If no explicit routine directive applies to a procedure whose definition
+// appears in the program unit being compiled, then the implementation applies
+// an implicit routine directive to that procedure if any of the following
+// conditions holds:
+// - The procedure is called or its address is accessed in a compute region."
+//
+// The specification further states:
+// "When the implementation applies an implicit routine directive to a
+// procedure, it must recursively apply implicit routine directives to other
+// procedures for which the above rules specify relevant dependencies. Such
+// dependencies can form a cycle, so the implementation must take care to avoid
+// infinite recursion."
+//
+// This pass implements these requirements by:
+// 1. Walking through all OpenACC compute constructs and functions already
+// marked with `acc routine` in the module and identifying function calls
+// within these regions.
+// 2. Creating implicit `acc.routine` operations for functions that don't
+// already have routine declarations.
+// 3. Recursively walking through all existing `acc routine` and creating
+// implicit routine operations for function calls within these routines,
+// while avoiding infinite recursion through proper tracking.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that define functions
+// or call functions should implement `mlir::FunctionOpInterface` and
+// `mlir::CallOpInterface` respectively.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+// determining if a symbol use is valid within GPU regions, the dialect
+// should pre-register the `acc::OpenACCSupport` analysis.
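+//
+// Example (illustrative sketch only; op syntax abbreviated and attribute
+// contents elided):
+//
+// Before:
+//   func.func @callee() { ... }
+//   func.func @test() {
+//     acc.parallel {
+//       func.call @callee() : () -> ()
+//       acc.yield
+//     }
+//   }
+//
+// After:
+//   acc.routine @acc_routine_0 func(@callee)
+//   func.func @callee() attributes {acc.routine_info = ...} { ... }
+//   func.func @test() { ... }  // unchanged
+//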
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include <queue>
+
+#define DEBUG_TYPE "acc-implicit-routine"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITROUTINE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+namespace {
+
+using namespace mlir;
+
+class ACCImplicitRoutine
+ : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> {
+private:
+ unsigned routineCounter = 0;
+ static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
+
+ // Count existing routine operations and update counter
+ void initRoutineCounter(ModuleOp module) {
+ module.walk([&](acc::RoutineOp routineOp) { routineCounter++; });
+ }
+
+  // Returns true if the `acc routine` has a default bind clause or a
+  // device-type specific bind clause.
+ bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op,
+ acc::DeviceType deviceType) {
+ // Fast check to avoid device-type specific lookups.
+ if (!op.getBindIdName() && !op.getBindStrName())
+ return false;
+ return op.getBindNameValue().has_value() ||
+ op.getBindNameValue(deviceType).has_value();
+ }
+
+ // Generate a unique name for the routine and create the routine operation
+ acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc,
+ FunctionOpInterface &callee) {
+ std::string routineName =
+ (accRoutinePrefix + std::to_string(routineCounter++)).str();
+ auto routineOp = acc::RoutineOp::create(
+ builder, loc,
+ /* sym_name=*/builder.getStringAttr(routineName),
+ /* func_name=*/
+ mlir::SymbolRefAttr::get(builder.getContext(),
+ builder.getStringAttr(callee.getName())),
+ /* bindIdName=*/nullptr,
+ /* bindStrName=*/nullptr,
+ /* bindIdNameDeviceType=*/nullptr,
+ /* bindStrNameDeviceType=*/nullptr,
+ /* worker=*/nullptr,
+ /* vector=*/nullptr,
+ /* seq=*/nullptr,
+ /* nohost=*/nullptr,
+ /* implicit=*/builder.getUnitAttr(),
+ /* gang=*/nullptr,
+ /* gangDim=*/nullptr,
+ /* gangDimDeviceType=*/nullptr);
+
+ // Assert that the callee does not already have routine info attribute
+ assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) &&
+ "function is already associated with a routine");
+
+ callee->setAttr(
+ acc::getRoutineInfoAttrName(),
+ mlir::acc::RoutineInfoAttr::get(
+ builder.getContext(),
+ {mlir::SymbolRefAttr::get(builder.getContext(),
+ builder.getStringAttr(routineName))}));
+ return routineOp;
+ }
+
+ // Used to walk through a compute region looking for function calls.
+ void
+ implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab,
+ mlir::OpBuilder &builder,
+ acc::OpenACCSupport &accSupport) {
+ op->walk([&](CallOpInterface callOp) {
+ if (!callOp.getCallableForCallee())
+ return;
+
+ auto calleeSymbolRef =
+ dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
+      // When the call is made through an SSA value, the callee is not a
+      // symbol. Skip it because we don't know the call target.
+ if (!calleeSymbolRef)
+ return;
+
+ auto callee = symTab.lookup<FunctionOpInterface>(
+ calleeSymbolRef.getLeafReference().str());
+      assert(callee && "callee function must be found in symbol table");
+
+      // If the callee is already a valid symbol for GPU regions, skip it.
+ if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
+ return;
+ builder.setInsertionPoint(callee);
+ createRoutineOp(builder, callee.getLoc(), callee);
+ });
+ }
+
+ // Recursively handle calls within a routine operation
+ void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp,
+ mlir::OpBuilder &builder,
+ acc::OpenACCSupport &accSupport,
+ acc::DeviceType targetDeviceType) {
+    // When a bind clause is used, the target is different from the function
+    // the `acc routine` is attached to. Skip this case to avoid recursively
+    // marking implicit routines for calls that would not end up on the device.
+ if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType))
+ return;
+
+ SymbolTable symTab(routineOp->getParentOfType<ModuleOp>());
+ std::queue<acc::RoutineOp> routineQueue;
+ routineQueue.push(routineOp);
+ while (!routineQueue.empty()) {
+ auto currentRoutine = routineQueue.front();
+ routineQueue.pop();
+ auto func = symTab.lookup<FunctionOpInterface>(
+ currentRoutine.getFuncName().getLeafReference());
+ func.walk([&](CallOpInterface callOp) {
+ if (!callOp.getCallableForCallee())
+ return;
+
+ auto calleeSymbolRef =
+ dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
+        // When the call is made through an SSA value, the callee is not a
+        // symbol. Skip it because we don't know the call target.
+ if (!calleeSymbolRef)
+ return;
+
+ auto callee = symTab.lookup<FunctionOpInterface>(
+ calleeSymbolRef.getLeafReference().str());
+        assert(callee && "callee function must be found in symbol table");
+        // If the callee is already a valid symbol for GPU regions, skip it.
+ if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
+ return;
+ builder.setInsertionPoint(callee);
+ auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee);
+ routineQueue.push(newRoutineOp);
+ });
+ }
+ }
+
+public:
+ using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase;
+
+ void runOnOperation() override {
+ auto module = getOperation();
+ mlir::OpBuilder builder(module.getContext());
+ SymbolTable symTab(module);
+ initRoutineCounter(module);
+
+ acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
+
+ // Handle compute regions
+ module.walk([&](Operation *op) {
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
+ implicitRoutineForCallsInComputeRegions(op, symTab, builder,
+ accSupport);
+ });
+
+ // Use the device type option from the pass options.
+ acc::DeviceType targetDeviceType = deviceType;
+
+ // Handle existing routines
+ module.walk([&](acc::RoutineOp routineOp) {
+ implicitRoutineForCallsInRoutine(routineOp, builder, accSupport,
+ targetDeviceType);
+ });
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
new file mode 100644
index 0000000..f41ce276
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
@@ -0,0 +1,117 @@
+//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts acc.serial into acc.parallel with num_gangs(1)
+// num_workers(1) vector_length(1).
+//
+// This transformation simplifies processing of acc regions by unifying the
+// handling of serial and parallel constructs. Since an OpenACC serial region
+// executes sequentially (like a parallel region with a single gang, worker, and
+// vector), this conversion is semantically equivalent while enabling code reuse
+// in later compilation stages.
+//
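+// As an illustrative sketch only (hypothetical IR, operand/attribute syntax
+// abbreviated):
+//
+// Before:
+//   acc.serial {
+//     ...
+//     acc.yield
+//   }
+//
+// After:
+//   %c1 = arith.constant 1 : i32
+//   acc.parallel num_gangs({%c1 : i32}) num_workers(%c1 : i32)
+//       vector_length(%c1 : i32) {
+//     ...
+//     acc.yield
+//   }
+//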
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCLEGALIZESERIAL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-legalize-serial"
+
+namespace {
+using namespace mlir;
+
+struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
+ using OpRewritePattern<acc::SerialOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(acc::SerialOp serialOp,
+ PatternRewriter &rewriter) const override {
+
+ const Location loc = serialOp.getLoc();
+
+    // Create a container holding the constant value of 1 for use as the
+    // num_gangs, num_workers, and vector_length operands.
+ llvm::SmallVector<mlir::Value> numValues;
+ auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
+ numValues.push_back(value);
+
+    // The num_gangs operand is variadic and grouped per device_type, so create
+    // a segment attribute describing the single value.
+ llvm::SmallVector<int32_t> numGangsSegments;
+ numGangsSegments.push_back(numValues.size());
+ auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);
+
+    // Create a device_type attribute set to `none` so that the parallelism
+    // dimensions apply as the default (no specific device_type).
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ rewriter.getContext(), mlir::acc::DeviceType::None);
+ crtDeviceTypes.push_back(crtDeviceTypeAttr);
+ auto devTypeAttr =
+ mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);
+
+ LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");
+
+ // Create a new acc.parallel op with the same operands - except include the
+ // num_gangs, num_workers, and vector_length attributes.
+ acc::ParallelOp parOp = acc::ParallelOp::create(
+ rewriter, loc, serialOp.getAsyncOperands(),
+ serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
+ serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
+ serialOp.getWaitOperandsDeviceTypeAttr(),
+ serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
+ gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
+ devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
+ serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
+ serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
+ serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
+ serialOp.getCombinedAttr());
+
+ parOp.getRegion().takeBody(serialOp.getRegion());
+
+ LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
+ rewriter.replaceOp(serialOp, parOp);
+
+ return success();
+ }
+};
+
+class ACCLegalizeSerial
+ : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
+public:
+ using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns(context);
+ patterns.insert<ACCSerialOpConversion>(context);
+ (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 7d93495..10a1796 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,8 @@
add_mlir_dialect_library(MLIROpenACCTransforms
+ ACCImplicitData.cpp
+ ACCImplicitDeclare.cpp
+ ACCImplicitRoutine.cpp
+ ACCLegalizeSerial.cpp
LegalizeDataValues.cpp
ADDITIONAL_HEADER_DIRS
@@ -14,7 +18,10 @@ add_mlir_dialect_library(MLIROpenACCTransforms
MLIROpenACCTypeInterfacesIncGen
LINK_LIBS PUBLIC
+ MLIRAnalysis
+ MLIROpenACCAnalysis
MLIROpenACCDialect
+ MLIROpenACCUtils
MLIRFuncDialect
MLIRIR
MLIRPass
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index fbac28e..7f27b44 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,8 +9,13 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Casting.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -155,3 +160,109 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) {
return val;
}
+
+bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
+ mlir::SymbolRefAttr symbol,
+ mlir::Operation **definingOpPtr) {
+ mlir::Operation *definingOp =
+ mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol);
+
+  // If there is no defining op, we have no way to ensure validity because
+  // we cannot check for any attributes.
+ if (!definingOp)
+ return false;
+
+ if (definingOpPtr)
+ *definingOpPtr = definingOp;
+
+  // Check if the defining op is a recipe (private, reduction, firstprivate).
+  // Recipes are valid since they are materialized before being offloaded to
+  // the device; they are only instructions for how to materialize a value.
+ if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp,
+ mlir::acc::FirstprivateRecipeOp>(definingOp))
+ return true;
+
+ // Check if the defining op is a function
+ if (auto func =
+ mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) {
+    // If this symbol is actually an acc routine, it is expected to be
+    // offloaded - therefore it is valid.
+ if (func->hasAttr(mlir::acc::getRoutineInfoAttrName()))
+ return true;
+
+ // If this symbol is a call to an LLVM intrinsic, then it is likely valid.
+ // Check the following:
+ // 1. The function is private
+ // 2. The function has no body
+ // 3. Name starts with "llvm."
+ // 4. The function's name is a valid LLVM intrinsic name
+ if (func.getVisibility() == mlir::SymbolTable::Visibility::Private &&
+ func.getFunctionBody().empty() && func.getName().starts_with("llvm.") &&
+ llvm::Intrinsic::lookupIntrinsicID(func.getName()) !=
+ llvm::Intrinsic::not_intrinsic)
+ return true;
+ }
+
+ // A declare attribute is needed for symbol references.
+ bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName());
+ return hasDeclare;
+}
+
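+// Informally, the data clauses that "dominate" a compute construct are those
+// guaranteed to be active whenever the construct executes: the construct's own
+// data clause operands, those of enclosing acc.data constructs, and those of
+// acc.declare_enter/acc.declare_exit pairs that dominate/post-dominate the
+// construct. A sketch of the declare case (hypothetical IR, printed syntax
+// abbreviated):
+//
+//   %tok = acc.declare_enter dataOperands(%cp : memref<10xf32>)
+//   acc.parallel { ... }  // %cp counts as a dominating data clause
+//   acc.declare_exit token(%tok)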
+llvm::SmallVector<mlir::Value>
+mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp,
+ mlir::DominanceInfo &domInfo,
+ mlir::PostDominanceInfo &postDomInfo) {
+ llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses;
+
+ llvm::TypeSwitch<mlir::Operation *>(computeConstructOp)
+ .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>(
+ [&](auto op) {
+ for (auto dataClause : op.getDataClauseOperands()) {
+ dominatingDataClauses.insert(dataClause);
+ }
+ })
+ .Default([](mlir::Operation *) {});
+
+ // Collect the data clauses from enclosing data constructs.
+ mlir::Operation *currParentOp = computeConstructOp->getParentOp();
+ while (currParentOp) {
+ if (mlir::isa<mlir::acc::DataOp>(currParentOp)) {
+ for (auto dataClause : mlir::dyn_cast<mlir::acc::DataOp>(currParentOp)
+ .getDataClauseOperands()) {
+ dominatingDataClauses.insert(dataClause);
+ }
+ }
+ currParentOp = currParentOp->getParentOp();
+ }
+
+ // Find the enclosing function/subroutine
+ auto funcOp =
+ computeConstructOp->getParentOfType<mlir::FunctionOpInterface>();
+ if (!funcOp)
+ return dominatingDataClauses.takeVector();
+
+ // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that
+ // dominate and post-dominate the compute construct and add their data
+ // clauses to the list.
+ funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) {
+ if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) {
+ // Collect all `acc.declare_exit` ops for this token.
+ llvm::SmallVector<mlir::acc::DeclareExitOp> exits;
+ for (auto *user : declareEnterOp.getToken().getUsers())
+ if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user))
+ exits.push_back(declareExit);
+
+ // Only add clauses if every `acc.declare_exit` op post-dominates the
+ // compute construct.
+ if (!exits.empty() &&
+ llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) {
+ return postDomInfo.postDominates(exitOp, computeConstructOp);
+ })) {
+ for (auto dataClause : declareEnterOp.getDataClauseOperands())
+ dominatingDataClauses.insert(dataClause);
+ }
+ }
+ });
+
+ return dominatingDataClauses.takeVector();
+}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1b069c6..103295d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -617,6 +617,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
break;
case ClauseScheduleKind::Auto:
case ClauseScheduleKind::Runtime:
+ case ClauseScheduleKind::Distribute:
chunkSize = std::nullopt;
}
@@ -1817,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
if (mapTypeMod == "ref_ptr_ptee")
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
+ if (mapTypeMod == "is_device_ptr")
+ mapTypeBits |= ClauseMapFlags::is_device_ptr;
+
return success();
};
@@ -1886,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
mapTypeStrs.push_back("ref_ptee");
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
mapTypeStrs.push_back("ref_ptr_ptee");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
+ mapTypeStrs.push_back("is_device_ptr");
if (mapFlags == ClauseMapFlags::none)
mapTypeStrs.push_back("none");
@@ -2824,6 +2830,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
+        /*linear_var_types=*/nullptr,
/*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
/*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/false,
@@ -2842,8 +2849,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
WsloopOp::build(
builder, state,
/*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
- clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
- clauses.ordered, clauses.privateVars,
+ clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait,
+ clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
@@ -2888,17 +2895,16 @@ LogicalResult WsloopOp::verifyRegions() {
void SimdOp::build(OpBuilder &builder, OperationState &state,
const SimdOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- // TODO Store clauses in op: linearVars, linearStepVars
- SimdOp::build(builder, state, clauses.alignedVars,
- makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
- /*linear_vars=*/{}, /*linear_step_vars=*/{},
- clauses.nontemporalVars, clauses.order, clauses.orderMod,
- clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.reductionMod,
- clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
- clauses.simdlen);
+ SimdOp::build(
+ builder, state, clauses.alignedVars,
+ makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
+ clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes,
+ clauses.nontemporalVars, clauses.order, clauses.orderMod,
+ clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
+ clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
+ clauses.simdlen);
}
LogicalResult SimdOp::verify() {
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 423e1c3..b111117 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
MLIRValueBoundsOpInterface
+ MLIRTransformUtils
)
-
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2946b53..c4bd31f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -2565,6 +2566,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
struct ConditionPropagation : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
+ /// Kind of parent region in the ancestor cache.
+ enum class Parent { Then, Else, None };
+
+ /// Returns the kind of region ("then", "else", or "none") of the
+ /// IfOp that the given region is transitively nested in. Updates
+ /// the cache accordingly.
+ static Parent getParentType(Region *toCheck, IfOp op,
+ DenseMap<Region *, Parent> &cache,
+ Region *endRegion) {
+ SmallVector<Region *> seen;
+ while (toCheck != endRegion) {
+ auto found = cache.find(toCheck);
+ if (found != cache.end())
+ return found->second;
+ seen.push_back(toCheck);
+ if (&op.getThenRegion() == toCheck) {
+ for (Region *region : seen)
+ cache[region] = Parent::Then;
+ return Parent::Then;
+ }
+ if (&op.getElseRegion() == toCheck) {
+ for (Region *region : seen)
+ cache[region] = Parent::Else;
+ return Parent::Else;
+ }
+ toCheck = toCheck->getParentRegion();
+ }
+
+ for (Region *region : seen)
+ cache[region] = Parent::None;
+ return Parent::None;
+ }
+
LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
@@ -2580,9 +2614,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
Value constantTrue = nullptr;
Value constantFalse = nullptr;
+ DenseMap<Region *, Parent> cache;
for (OpOperand &use :
llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
+ switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
+ op.getCondition().getParentRegion())) {
+ case Parent::Then: {
changed = true;
if (!constantTrue)
@@ -2591,8 +2628,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
+ break;
+ }
+ case Parent::Else: {
changed = true;
if (!constantFalse)
@@ -2601,6 +2639,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
+ break;
+ }
+ case Parent::None:
+ break;
}
}
@@ -3646,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() {
}
namespace {
+/// Move a scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// %res = scf.if %cond -> (...) {
+/// use(%additional_used_values)
+/// ... // then block
+/// scf.yield %then_value
+/// } else {
+/// scf.yield %else_value
+/// }
+/// scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+/// use(%res_arg)
+/// ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+/// ^bb0(%res_arg, ..., %additional_args):
+/// use(%additional_args)
+/// ... // if then block
+/// use(%then_value)
+/// ...
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+ using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::WhileOp op,
+ PatternRewriter &rewriter) const override {
+ auto conditionOp = op.getConditionOp();
+
+ // Only support an ifOp directly before the condition for now. Relaxing this
+ // would require:
+ // - check that the body does not have side-effects conflicting with
+ // operations between the if and the condition.
+ // - check that results of the if operation are only used as arguments to
+ // the condition.
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+ // Check that the ifOp is directly before the conditionOp and that it
+ // matches the condition of the conditionOp. Also ensure that the ifOp has
+ // no else block with content, as that would complicate the transformation.
+ // TODO: support else blocks with content.
+ if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+ (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+ return failure();
+
+ assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+ *ifOp->user_begin() == conditionOp)) &&
+ "ifOp has unexpected uses");
+
+ Location loc = op.getLoc();
+
+ // Replace uses of ifOp results in the conditionOp with the yielded values
+ // from the ifOp branches.
+ for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+ auto it = llvm::find(ifOp->getResults(), arg);
+ if (it != ifOp->getResults().end()) {
+ size_t ifOpIdx = it.getIndex();
+ Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+ Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+ rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+ rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+ }
+ }
+
+ // Collect additional used values from before region.
+ SetVector<Value> additionalUsedValuesSet;
+ visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
+ if (&op.getBefore() == operand->get().getParentRegion())
+ additionalUsedValuesSet.insert(operand->get());
+ });
+
+ // Create new whileOp with additional used values as results.
+ auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
+ auto additionalValueTypes = llvm::map_to_vector(
+ additionalUsedValues, [](Value val) { return val.getType(); });
+ size_t additionalValueSize = additionalUsedValues.size();
+ SmallVector<Type> newResultTypes(op.getResultTypes());
+ newResultTypes.append(additionalValueTypes);
+
+ auto newWhileOp =
+ scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+ rewriter.modifyOpInPlace(newWhileOp, [&] {
+ newWhileOp.getBefore().takeBody(op.getBefore());
+ newWhileOp.getAfter().takeBody(op.getAfter());
+ newWhileOp.getAfter().addArguments(
+ additionalValueTypes,
+ SmallVector<Location>(additionalValueSize, loc));
+ });
+
+ rewriter.modifyOpInPlace(conditionOp, [&] {
+ conditionOp.getArgsMutable().append(additionalUsedValues);
+ });
+
+ // Replace uses of additional used values inside the ifOp then region with
+ // the whileOp after region arguments.
+ rewriter.replaceUsesWithIf(
+ additionalUsedValues,
+ newWhileOp.getAfterArguments().take_back(additionalValueSize),
+ [&](OpOperand &use) {
+ return ifOp.getThenRegion().isAncestor(
+ use.getOwner()->getParentRegion());
+ });
+
+ // Inline ifOp then region into new whileOp after region.
+ rewriter.eraseOp(ifOp.thenYield());
+ rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+ newWhileOp.getAfterBody()->begin());
+ rewriter.eraseOp(ifOp);
+ rewriter.replaceOp(op,
+ newWhileOp->getResults().drop_back(additionalValueSize));
+ return success();
+ }
+};
+
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
@@ -4302,7 +4471,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp loop,
PatternRewriter &rewriter) const override {
- auto oldBefore = loop.getBeforeBody();
+ auto *oldBefore = loop.getBeforeBody();
ConditionOp oldTerm = loop.getConditionOp();
ValueRange beforeArgs = oldBefore->getArguments();
ValueRange termArgs = oldTerm.getArgs();
@@ -4323,7 +4492,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
beforeArgs);
}
- auto oldAfter = loop.getAfterBody();
+ auto *oldAfter = loop.getAfterBody();
SmallVector<Type> newResultTypes(beforeArgs.size());
for (auto &&[i, j] : llvm::enumerate(*mapping))
@@ -4332,8 +4501,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
auto newLoop = WhileOp::create(
rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
/*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
- auto newBefore = newLoop.getBeforeBody();
- auto newAfter = newLoop.getAfterBody();
+ auto *newBefore = newLoop.getBeforeBody();
+ auto *newAfter = newLoop.getAfterBody();
SmallVector<Value> newResults(beforeArgs.size());
SmallVector<Value> newAfterArgs(beforeArgs.size());
@@ -4358,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
index 8f7d5e3..c469a99 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -44,7 +44,6 @@ mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
lowerBounds.size() == steps.size() &&
"Mismatched parallel loop bounds");
- SmallVector<Value> ivs;
scf::LoopNest loopNest =
scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770f..009c2c3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
// Again assume that all the outer loops are scf.for operations.
- auto outerForLoop = cast<scf::ForOp>(outerLoop);
+ auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
auto outerLoopYield =
cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
return clonedSlices;
}
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
- RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
- MutableArrayRef<LoopLikeOpInterface> loops) {
- if (candidateSlices.empty()) {
- return rewriter.notifyMatchFailure(
- rewriter.getUnknownLoc(),
- "no candidate slices provided for consumer fusion");
- }
- // Return if `loops` is empty, return an error for now. Caller is expected
- // to handle this case.
- if (loops.empty()) {
- return rewriter.notifyMatchFailure(
- candidateSlices.front(),
- "cannot call tile and fuse consumer with an empty loop nest");
- }
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+ ArrayRef<OpOperand *> consumerOpOperands,
+ ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "expected loops to be not empty");
- if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
- llvm::all_of(candidateSlices,
- llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+ // 1. Check assumption for loop with `reorderOperations` disabled.
+ if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
return rewriter.notifyMatchFailure(
- candidateSlices.front(),
- "candidates slices need to be all `tensor.extract_slice`s or "
- "`tensor.parallel_insert_slice`s");
- }
-
- // 1. Get the consumer of scf.for for the result yielded by
- // tensor.insert_slice/parallel_insert_slice.
- SmallVector<OpOperand *> consumerOpOperands;
- Operation *consumerOp;
- {
- FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
- getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
- if (failed(maybeConsumerOpOperand)) {
- return rewriter.notifyMatchFailure(candidateSlices.front(),
- "could not fetch consumer to fuse");
- }
- std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
- consumerOp = consumerOpOperands.front()->getOwner();
+ loops.front(), "the first user of loop should not dominate any define "
+ "of consumer operand(s)");
}
LoopLikeOpInterface outerMostLoop = loops.front();
LoopLikeOpInterface innerMostLoop = loops.back();
- // Check assumption for loop with `reorderOperations` disabled.
- if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
- return rewriter.notifyMatchFailure(
- outerMostLoop, "the first user of loop should not dominate any define "
- "of consumer operand(s)");
- }
-
OpBuilder::InsertionGuard g(rewriter);
-
// 2. Check consumer is not using scf loop's output as init.
auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
if (!dstOp)
@@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices(
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
});
+ auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
return scf::SCFFuseConsumerOfSliceResult{
- std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+ std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
std::move(tileAndFuseResult->tiledOps)};
}
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+ RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (candidateSlices.empty()) {
+ return rewriter.notifyMatchFailure(
+ rewriter.getUnknownLoc(),
+ "no candidate slices provided for consumer fusion");
+ }
+ // If `loops` is empty, return an error for now. The caller is expected
+ // to handle this case.
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(
+ candidateSlices.front(),
+ "cannot call tile and fuse consumer with an empty loop nest");
+ }
+
+ if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+ llvm::all_of(candidateSlices,
+ llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+ return rewriter.notifyMatchFailure(
+ candidateSlices.front(),
+ "candidates slices need to be all `tensor.extract_slice`s or "
+ "`tensor.parallel_insert_slice`s");
+ }
+
+ // Get the consumer of scf.for for the result yielded by
+ // tensor.insert_slice/parallel_insert_slice.
+ FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
+ getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+ if (failed(maybeConsumerOpOperands)) {
+ return rewriter.notifyMatchFailure(candidateSlices.front(),
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
+
+ return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
+ maybeConsumerOpOperands.value(),
+ candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp`, return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+ if (result.getOwner() != forallOp)
+ return std::nullopt;
+ BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+ SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+ // If the number of combining ops is not 1, then this is unexpected. Return
+ // nullopt.
+ if (combiningOps.size() != 1)
+ return std::nullopt;
+ return combiningOps[0];
+}
+
+/// For a given result of a tiled loop nest, return the insert slice-like op
+/// that is used for consumer fusion.
+static std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+ ArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "Expected loops to be not empty");
+ LoopLikeOpInterface outerMostLoop = loops.front();
+ if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
+ assert(loops.size() == 1 &&
+ "expected only a single loop when tiling using scf.forall");
+ return getProducingParallelInsertSlice(forallOp, result);
+ }
+ // Assume that the loop nest is a nested `scf.for` that is created through
+ // tiling and retrieve the `tensor.insert_slice` operation used to construct
+ // the result.
+ while (loops.size() != 1) {
+ LoopLikeOpInterface loop = loops.front();
+ if (result.getOwner() != loop)
+ return std::nullopt;
+ auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+ if (!forOp)
+ return std::nullopt;
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ auto innerForResult =
+ dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+ if (!innerForResult)
+ return std::nullopt;
+ result = innerForResult;
+ loops = loops.drop_front();
+ }
+ LoopLikeOpInterface loop = loops.front();
+ if (result.getOwner() != loop)
+ return std::nullopt;
+ auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+ if (!forOp)
+ return std::nullopt;
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+ .getDefiningOp<tensor::InsertSliceOp>();
+ if (!insertSliceOp)
+ return std::nullopt;
+ return insertSliceOp;
+}
+
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (!isa<TilingInterface>(consumer)) {
+ return rewriter.notifyMatchFailure(
+ consumer, "unhandled consumer that does not implement TilingInterface");
+ }
+
+ // If `loops` is empty, return an error for now. The caller is expected
+ // to handle this case.
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(
+ consumer, "cannot call tile and fuse consumer with an empty loop nest");
+ }
+
+ LoopLikeOpInterface outermostLoop = loops.front();
+
+ // Collect the operands of the consumer that come from the outermost loop of
+ // the loop nest.
+ SmallVector<OpOperand *> consumerFusableOperands;
+ for (OpOperand &opOperand : consumer->getOpOperands()) {
+ if (opOperand.get().getDefiningOp() == outermostLoop) {
+ consumerFusableOperands.push_back(&opOperand);
+ }
+ }
+
+ // Nothing to fuse. Just return an empty set.
+ if (consumerFusableOperands.empty()) {
+ return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+ SmallVector<OpOperand *>{},
+ SmallVector<Operation *>{}};
+ }
+
+ // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+ // for fusion.
+ SmallVector<Operation *> candidateSlices;
+ candidateSlices.reserve(consumerFusableOperands.size());
+ for (OpOperand *opOperand : consumerFusableOperands) {
+ std::optional<Operation *> slice =
+ getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+ if (!slice) {
+ return rewriter.notifyMatchFailure(
+ consumer,
+ "couldnt find producing insert-slice like operation for operand");
+ }
+ candidateSlices.push_back(slice.value());
+ }
+ return tileAndFuseConsumerOfSlicesImpl(
+ rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
+}
+
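For orientation, a minimal usage sketch of the new `scf::tileAndFuseConsumer` entry point; the surrounding `rewriter`, `tilingResult`, and `consumerOp` names are assumptions, not part of this patch:

    // Fuse `consumerOp` into the loop nest produced by an earlier tiling step.
    FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
        scf::tileAndFuseConsumer(rewriter, consumerOp, tilingResult.loops);
    if (failed(fused))
      return failure();
    // On success, the result records the fused consumer operands and the tiled
    // consumer op(s) created inside the loop nest.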
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index f0b46e6..a846d7e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -220,6 +220,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
}
//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+ Block *defaultTarget, ValueRange defaultOperands,
+ DenseIntElementsAttr literals, BlockRange targets,
+ ArrayRef<ValueRange> targetOperands) {
+ build(builder, result, selector, defaultOperands, targetOperands, literals,
+ defaultTarget, targets);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+ Block *defaultTarget, ValueRange defaultOperands,
+ ArrayRef<APInt> literals, BlockRange targets,
+ ArrayRef<ValueRange> targetOperands) {
+ DenseIntElementsAttr literalsAttr;
+ if (!literals.empty()) {
+ ShapedType literalType = VectorType::get(
+ static_cast<int64_t>(literals.size()), selector.getType());
+ literalsAttr = DenseIntElementsAttr::get(literalType, literals);
+ }
+ build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
+ targets, targetOperands);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+ Block *defaultTarget, ValueRange defaultOperands,
+ ArrayRef<int32_t> literals, BlockRange targets,
+ ArrayRef<ValueRange> targetOperands) {
+ DenseIntElementsAttr literalsAttr;
+ if (!literals.empty()) {
+ ShapedType literalType = VectorType::get(
+ static_cast<int64_t>(literals.size()), selector.getType());
+ literalsAttr = DenseIntElementsAttr::get(literalType, literals);
+ }
+ build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
+ targets, targetOperands);
+}
+
+LogicalResult SwitchOp::verify() {
+ std::optional<DenseIntElementsAttr> literals = getLiterals();
+ BlockRange targets = getTargets();
+
+ if (!literals && targets.empty())
+ return success();
+
+ Type selectorType = getSelector().getType();
+ Type literalType = literals->getType().getElementType();
+ if (literalType != selectorType)
+ return emitOpError() << "'selector' type (" << selectorType
+ << ") should match literals type (" << literalType
+ << ")";
+
+ if (literals && literals->size() != static_cast<int64_t>(targets.size()))
+ return emitOpError() << "number of literals (" << literals->size()
+ << ") should match number of targets ("
+ << targets.size() << ")";
+ return success();
+}
+
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
+ assert(index < getNumSuccessors() && "invalid successor index");
+ return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+ : getTargetOperandsMutable(index - 1));
+}
+
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+ std::optional<DenseIntElementsAttr> literals = getLiterals();
+
+ if (!literals)
+ return getDefaultTarget();
+
+ SuccessorRange targets = getTargets();
+ if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) {
+ for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
+ if (literal == value.getValue())
+ return targets[index];
+ return getDefaultTarget();
+ }
+ return nullptr;
+}
+
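As a quick illustration of the `ArrayRef<int32_t>` builder above, a hedged sketch (assuming an i32 `selector`, a `defaultBlock`, and two case blocks already exist; these names are not from the patch):

    SmallVector<Block *> caseBlocks = {caseZeroBlock, caseOneBlock};
    spirv::SwitchOp::create(builder, loc, selector,
                            /*defaultTarget=*/defaultBlock,
                            /*defaultOperands=*/ValueRange(),
                            /*literals=*/ArrayRef<int32_t>{0, 1},
                            /*targets=*/caseBlocks,
                            /*targetOperands=*/{ValueRange(), ValueRange()});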
+//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index 2f3a28f..8575487 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
}
}
+/// Adapted from the cf.switch implementation.
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+static ParseResult parseSwitchOpCases(
+ OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
+ SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
+ SmallVectorImpl<Block *> &targets,
+ SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
+ &targetOperands,
+ SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
+ if (parser.parseKeyword("default") || parser.parseColon() ||
+ parser.parseSuccessor(defaultTarget))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
+ /*allowResultNumber=*/false) ||
+ parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
+ return failure();
+ }
+
+ SmallVector<APInt> values;
+ unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
+ while (succeeded(parser.parseOptionalComma())) {
+ int64_t value = 0;
+ if (failed(parser.parseInteger(value)))
+ return failure();
+ values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
+
+ Block *target;
+ SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ SmallVector<Type> operandTypes;
+ if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (failed(parser.parseOperandList(operands,
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseColonTypeList(operandTypes)) ||
+ failed(parser.parseRParen()))
+ return failure();
+ }
+ targets.push_back(target);
+ targetOperands.emplace_back(operands);
+ targetOperandTypes.emplace_back(operandTypes);
+ }
+
+ if (!values.empty()) {
+ ShapedType literalType =
+ VectorType::get(static_cast<int64_t>(values.size()), selectorType);
+ literals = DenseIntElementsAttr::get(literalType, values);
+ }
+ return success();
+}
+
+static void
+printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
+ Block *defaultTarget, OperandRange defaultOperands,
+ TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
+ SuccessorRange targets, OperandRangeRange targetOperands,
+ const TypeRangeRange &targetOperandTypes) {
+ p << " default: ";
+ p.printSuccessorAndUseList(defaultTarget, defaultOperands);
+
+ if (!literals)
+ return;
+
+ for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
+ p << ',';
+ p.printNewline();
+ p << " ";
+ p << literal.getLimitedValue();
+ p << ": ";
+ p.printSuccessorAndUseList(targets[index], targetOperands[index]);
+ }
+ p.printNewline();
+}
+
} // namespace mlir::spirv
// TableGen'erated operation definitions.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index cb9b7f6..f07307f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
<< type << " illegal: cannot handle zero-element tensors\n");
return nullptr;
}
+ if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: cannot fit tensor into target type\n");
+ return nullptr;
+ }
Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
if (!arrayElemType)
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 645cbff..5941f7d 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames(
//===----------------------------------------------------------------------===//
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr grid,
- ArrayRef<GridAxesAttr> split_axes,
- ArrayRef<int64_t> static_halos,
- ArrayRef<int64_t> static_offsets) {
+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
+ ArrayRef<int64_t> staticHalos,
+ ArrayRef<int64_t> staticOffsets) {
return build(
- b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
- ArrayRef<int64_t> static_halos,
- ArrayRef<int64_t> static_offsets) {
+ llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
+ ArrayRef<int64_t> staticHalos,
+ ArrayRef<int64_t> staticOffsets) {
return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
- GridAxesArrayAttr::get(b.getContext(), split_axes),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
+ GridAxesArrayAttr::get(b.getContext(), splitAxes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
{});
}
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
+ ::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes,
+ ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
- dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
- dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
+ dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
+ dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
return build(
- b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
@@ -576,7 +575,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
- getStaticShardedDimsOffsets().size() > 0) {
+ !getStaticShardedDimsOffsets().empty()) {
return emitError() << "sharded dims offsets are not allowed for "
"device grids with dynamic shape.";
}
@@ -650,14 +649,14 @@ public:
if (dynamicOffs.empty() && !staticOffs.empty()) {
assert(staticOffs.size() >= 2);
auto diff = staticOffs[1] - staticOffs[0];
- bool all_same = staticOffs.size() > 2;
+ bool allSame = staticOffs.size() > 2;
for (auto i = 2u; i < staticOffs.size(); ++i) {
if (staticOffs[i] - staticOffs[i - 1] != diff) {
- all_same = false;
+ allSame = false;
break;
}
}
- if (all_same) {
+ if (allSame) {
staticOffs.clear();
modified = true;
}
@@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const {
bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
-Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
+Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {}
Sharding::Sharding(Value rhs) {
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
@@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) {
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
-Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
- ArrayRef<GridAxesAttr> split_axes_,
- ArrayRef<int64_t> static_halo_sizes_,
- ArrayRef<int64_t> static_sharded_dims_offsets_,
- ArrayRef<Value> dynamic_halo_sizes_,
- ArrayRef<Value> dynamic_sharded_dims_offsets_) {
- Sharding res(grid_);
- if (split_axes_.empty()) {
+Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid,
+ ArrayRef<GridAxesAttr> splitAxes,
+ ArrayRef<int64_t> staticHaloSizes,
+ ArrayRef<int64_t> staticShardedDimsOffsets,
+ ArrayRef<Value> dynamicHaloSizes,
+ ArrayRef<Value> dynamicShardedDimsOffsets) {
+ Sharding res(grid);
+ if (splitAxes.empty()) {
return res;
}
- res.split_axes.resize(split_axes_.size());
- for (auto [i, axis] : llvm::enumerate(split_axes_)) {
- res.split_axes[i] =
- GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
+ res.split_axes.resize(splitAxes.size());
+ for (auto [i, axis] : llvm::enumerate(splitAxes)) {
+ res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
}
auto clone = [](const auto src, auto &dst) {
@@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
llvm::copy(src, dst.begin());
};
- clone(static_halo_sizes_, res.static_halo_sizes);
- clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
- clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
- clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
+ clone(staticHaloSizes, res.static_halo_sizes);
+ clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
+ clone(dynamicHaloSizes, res.dynamic_halo_sizes);
+ clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
return res;
}
@@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames(
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::llvm::ArrayRef<int64_t> dims,
- ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
+ ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
::mlir::ValueRange device) {
SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
- build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
+ build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index 3bfbf373..f954131 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -184,7 +184,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
for (auto [result, sharding] :
llvm::zip_equal(op->getResults(), resultShardings)) {
- for (auto user : result.getUsers()) {
+ for (auto *user : result.getUsers()) {
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
if (!shardOp) {
continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index ae7eef2..9db9814 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1365,8 +1365,8 @@ public:
arith::SubIOp::create(rewriter, loc, capacity, newSize);
Value fillValue = constantZero(rewriter, loc, value.getType());
Value subBuffer = memref::SubViewOp::create(
- rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize},
- /*size=*/ValueRange{fillSize},
+ rewriter, loc, newBuffer, /*offsets=*/ValueRange{newSize},
+ /*sizes=*/ValueRange{fillSize},
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
linalg::FillOp::create(rewriter, loc, fillValue, subBuffer);
}
@@ -1386,8 +1386,8 @@ public:
memref::StoreOp::create(rewriter, loc, value, buffer, size);
} else {
Value subBuffer = memref::SubViewOp::create(
- rewriter, loc, buffer, /*offset=*/ValueRange{size},
- /*size=*/ValueRange{n},
+ rewriter, loc, buffer, /*offsets=*/ValueRange{size},
+ /*sizes=*/ValueRange{n},
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
linalg::FillOp::create(rewriter, loc, value, subBuffer);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7a26cd3..1fbcf5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1050,7 +1050,7 @@ public:
/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
- using OpAdaptor = typename ToPositionsOp::Adaptor;
+ using OpAdaptor = ToPositionsOp::Adaptor;
using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
@@ -1073,7 +1073,7 @@ public:
class SparseToCoordinatesConverter
: public OpConversionPattern<ToCoordinatesOp> {
public:
- using OpAdaptor = typename ToCoordinatesOp::Adaptor;
+ using OpAdaptor = ToCoordinatesOp::Adaptor;
using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
@@ -1099,7 +1099,7 @@ public:
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> {
public:
- using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
+ using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
@@ -1121,7 +1121,7 @@ public:
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
- using OpAdaptor = typename ToValuesOp::Adaptor;
+ using OpAdaptor = ToValuesOp::Adaptor;
using OpConversionPattern<ToValuesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index febec6d..23436a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
SmallVector<Value> scalarArgs(idxs);
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
- vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask,
- rhs);
+ vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem,
+ scalarArgs, indexVec, vmask, rhs);
return;
}
vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 869d27a..7e8d360 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -22,7 +22,6 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
-#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index ffa8b40..9904803 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -80,6 +80,53 @@ inline static bool includesDenseOutput(SortMask mask) {
return includesAny(mask, SortMask::kIncludeDenseOutput);
}
+/// Returns a sparsity rank for loop ordering: lower values indicate loops
+/// that should be placed outermost.
+/// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown.
+static unsigned getLoopSparsityRank(unsigned loop, ArrayRef<Value> allTensors,
+ ArrayRef<AffineMap> allMaps) {
+ // Start with highest rank.
+ unsigned minRank = 3;
+
+ for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) {
+ // Check if this loop accesses this tensor.
+ bool loopAccessesTensor = false;
+ unsigned tensorDim = 0;
+ for (AffineExpr expr : map.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (dimExpr.getPosition() == loop) {
+ loopAccessesTensor = true;
+ break;
+ }
+ }
+ tensorDim++;
+ }
+
+ if (loopAccessesTensor) {
+ const auto enc = getSparseTensorEncoding(tensor.getType());
+ if (!enc) {
+ // Dense tensor - lowest rank.
+ return 0;
+ } else {
+ // Sparse tensor - check the level type for this dimension.
+ auto lvlTypes = enc.getLvlTypes();
+ if (tensorDim < lvlTypes.size()) {
+ auto lvlType = lvlTypes[tensorDim];
+ if (isDenseLT(lvlType)) {
+ return 0; // Dense level.
+ } else if (isCompressedLT(lvlType)) {
+ minRank = std::min(minRank, 1u); // Compressed level.
+ } else if (isSingletonLT(lvlType)) {
+ minRank = std::min(minRank, 2u); // Singleton level.
+ }
+ }
+ }
+ }
+ }
+
+ return minRank;
+}
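To make the kDenseOuter ranking concrete with a hypothetical example: for a reduction of a COO-encoded input (level types compressed, singleton) into a scalar, the loop mapped to the first level ranks 1 and the loop mapped to the second ranks 2, so the compressed loop is selected first and placed outermost; any loop that also touches an unannotated (dense) tensor or a dense level ranks 0 and wins outright.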
+
AffineMap IterationGraphSorter::topoSort() {
// The sorted result will put the first Reduction iterator to the
// latest possible position.
@@ -107,10 +154,33 @@ AffineMap IterationGraphSorter::topoSort() {
case sparse_tensor::LoopOrderingStrategy::kDefault:
src = it.back();
break;
+ case sparse_tensor::LoopOrderingStrategy::kDenseOuter: {
+ // Prefer dense, then compressed, then singleton dimensions outermost.
+ // Create combined tensor and map lists for analysis.
+ SmallVector<Value> allTensors = ins;
+ allTensors.push_back(out);
+ SmallVector<AffineMap> allMaps = loop2InsLvl;
+ allMaps.push_back(loop2OutLvl);
+
+ // Find loop with minimum (lowest) sparsity rank.
+ unsigned minLoop = it[0];
+ unsigned minRank = getLoopSparsityRank(minLoop, allTensors, allMaps);
+
+ for (auto candidateLoop : it) {
+ unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps);
+ if (rank < minRank || (rank == minRank && candidateLoop < minLoop)) {
+ minLoop = candidateLoop;
+ minRank = rank;
+ }
+ }
+ src = minLoop;
+ break;
+ }
}
loopOrder.push_back(src);
- it.pop_back();
+ // Remove the selected loop from the worklist.
+ it.erase(std::find(it.begin(), it.end(), src));
// Update in-degree, and push 0-degree node into worklist.
for (unsigned dst = 0; dst < numLoops; dst++) {
if (itGraph[src][dst] && --inDegree[dst] == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 3636f3f..46378b9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -197,7 +197,7 @@ public:
// Sets the iterate to the specified position.
void seek(ValueRange vals) {
assert(vals.size() == cursorValsCnt);
- std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
+ llvm::copy(vals, cursorValsStorageRef.begin());
// Now that the iterator is re-positioned, the coordinate becomes invalid.
crd = nullptr;
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e1..686f6ee 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
struct ReifyExpandShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
ExpandShapeOp> {
+ using Base =
+ ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp>;
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdc..204e9bb 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
auto tensorTypes =
- llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
- return llvm::cast<RankedTensorType>(type);
- }));
+ llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
int64_t concatRank = tensorTypes[0].getRank();
// The concatenation dim must be in the range [0, rank).
@@ -2293,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
@@ -2307,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}
// TODO: This uses neither offsets nor strides!
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
@@ -2329,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+ inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -2352,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
- staticStrides);
+ desiredResultRank, sourceRankedTensorType, staticSizes);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2380,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
+ resultType = llvm::cast<RankedTensorType>(
+ ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2451,13 +2445,26 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}
+/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
+/// result type, offsets set to 0 and strides set to 1.
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
+ RankedTensorType resultType, Value source,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<NamedAttribute> attrs) {
+ Attribute zeroIdxAttr = b.getIndexAttr(0);
+ Attribute oneIdxAttr = b.getIndexAttr(1);
+ SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
+ SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
+ build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
+}
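A hedged usage sketch of the sizes-only overload; `rewriter`, `loc`, `source`, `resultType`, and `sliceSizes` are assumed to exist, and the explicit empty attribute list reflects an assumption about the declared signature:

    // `sliceSizes` gives the slice extent in each dimension; offsets default
    // to 0 and strides to 1.
    auto slice = tensor::ExtractSliceOp::create(rewriter, loc, resultType,
                                                source, sliceSizes,
                                                /*attrs=*/{});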
+
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
- RankedTensorType expectedType = ExtractSliceOp::inferResultType(
- sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ RankedTensorType expectedType =
+ ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2697,8 +2704,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
- mixedStrides);
+ op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};
@@ -2839,8 +2845,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
- RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, staticOffsets, staticSizes, staticStrides);
+ RankedTensorType expected =
+ ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2968,7 +2974,7 @@ public:
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
- mixedOffsets, mixedSizes, mixedStrides);
+ mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
@@ -3896,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
+// to 0, strides set to 1 and inferred result type.
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+ Value dest, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<NamedAttribute> attrs) {
+ Attribute zeroIdxAttr = b.getIndexAttr(0);
+ Attribute oneIdxAttr = b.getIndexAttr(1);
+ SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
+ SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
+ build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
+}
+
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected InParallelOpInterface parent, got:")
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c607ece..310e725 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1132,35 +1132,22 @@ struct ConcatOpInterface
// Extract the dimension for the concat op
uint64_t concatDim = concatOp.getDim();
- bool dynamicConcatDim = false;
SmallVector<OpFoldResult> offsets(tensorType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(tensorType.getRank(),
rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes;
-
- for (const auto &[dimIdx, dimSize] :
- llvm::enumerate(tensorType.getShape())) {
- if (dimSize == ShapedType::kDynamic) {
- auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
- sizes.push_back(dimOp.getResult());
- if (dimIdx == concatDim)
- dynamicConcatDim = true;
- } else {
- sizes.push_back(rewriter.getIndexAttr(dimSize));
- }
- }
-
- int64_t concatDimOffset = 0;
- std::optional<Value> dynamicOffset;
- std::optional<Value> dynamicSize;
- if (dynamicConcatDim) {
- // One or more operands have dynamic size, so we must accumulate the
- // offset with arith ops.
- dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- }
+ SmallVector<OpFoldResult> sizes =
+ memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1,
+ {v1, v2});
+ };
+ OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1158,10 @@ struct ConcatOpInterface
// so the offset on that axis must accumulate through the loop, and the
// size must change to the size of the current operand.
auto operandTensorType = cast<RankedTensorType>(operand.getType());
- int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
- if (dynamicConcatDim) {
- offsets[concatDim] = dynamicOffset.value();
- dynamicSize =
- memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
- .getResult();
- sizes[concatDim] = dynamicSize.value();
- } else {
- sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
- offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
- }
+ offsets[concatDim] = concatDimOffset;
+ OpFoldResult concatDimSize =
+ memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+ sizes[concatDim] = concatDimSize;
// Create a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1176,7 @@ struct ConcatOpInterface
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
return failure();
- if (dynamicConcatDim) {
- dynamicOffset = arith::AddIOp::create(
- rewriter, loc, dynamicOffset.value(), dynamicSize.value());
- } else {
- concatDimOffset += operandConcatDimSize;
- }
+ concatDimOffset = sum(concatDimOffset, concatDimSize);
}
replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 7ec61c7..a53af98 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
- srcType, extractSliceOp.getStaticOffsets(),
- extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+ srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();
@@ -533,8 +532,8 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
getMixedSizes(b, loc, sliceOp.getSource());
// Helper variables and function for accumulating the size values.
- AffineExpr d0, d1, d2;
- bindDims(b.getContext(), d0, d1, d2);
+ AffineExpr d0, d1;
+ bindDims(b.getContext(), d0, d1);
// Multiply two integers.
auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
auto mulMap = AffineMap::get(2, 0, {d0 * d1});
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index 753cb95..d35f458 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -155,13 +155,15 @@ struct ExtractSliceOpInterface
RankedTensorType sourceType = extractSliceOp.getSource().getType();
// For each dimension, assert that:
- // 0 <= offset < dim_size
- // 0 <= offset + (size - 1) * stride < dim_size
+ // For empty slices (size == 0):     0 <= offset <= dim_size
+ // For non-empty slices (size > 0):  0 <= offset < dim_size
+ //                                   0 <= offset + (size - 1) * stride < dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
- // Reset insertion point to before the operation for each dimension
+
builder.setInsertionPoint(extractSliceOp);
Value offset = getValueOrCreateConstantIndexOp(
@@ -170,46 +172,63 @@ struct ExtractSliceOpInterface
builder, loc, extractSliceOp.getMixedSizes()[i]);
Value stride = getValueOrCreateConstantIndexOp(
builder, loc, extractSliceOp.getMixedStrides()[i]);
-
- // Verify that offset is in-bounds.
Value dimSize = builder.createOrFold<tensor::DimOp>(
loc, extractSliceOp.getSource(), i);
- Value offsetInBounds =
- generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(builder, loc, offsetInBounds,
+
+ // Verify that offset is in-bounds (conditional on slice size).
+ Value sizeIsZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, size, zero);
+ auto offsetCheckIf = scf::IfOp::create(
+ builder, loc, sizeIsZero,
+ [&](OpBuilder &b, Location loc) {
+ // For empty slices, offset can be at the boundary: 0 <= offset <=
+ // dimSize.
+ Value offsetGEZero = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sge, offset, zero);
+ Value offsetLEDimSize = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sle, offset, dimSize);
+ Value emptyOffsetValid =
+ arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
+ scf::YieldOp::create(b, loc, emptyOffsetValid);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // For non-empty slices, offset must be a valid index: 0 <= offset <
+ // dimSize.
+ Value offsetInBounds =
+ generateInBoundsCheck(b, loc, offset, zero, dimSize);
+ scf::YieldOp::create(b, loc, offsetInBounds);
+ });
+
+ Value offsetCondition = offsetCheckIf.getResult(0);
+ cf::AssertOp::create(builder, loc, offsetCondition,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));
- // Only verify if size > 0
+ // Verify that the slice endpoint is in-bounds (only for non-empty
+ // slices).
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);
+ auto ifOp = scf::IfOp::create(
+ builder, loc, sizeIsNonZero,
+ [&](OpBuilder &b, Location loc) {
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
+ Value sizeMinusOneTimesStride =
+ arith::MulIOp::create(b, loc, sizeMinusOne, stride);
+ Value lastPos =
+ arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds =
+ generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(b, loc, lastPosInBounds);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value trueVal =
+ arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
+ scf::YieldOp::create(b, loc, trueVal);
+ });
- auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
- sizeIsNonZero, /*withElseRegion=*/true);
-
- // Populate the "then" region (for size > 0).
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
- Value sizeMinusOneTimesStride =
- arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
- Value lastPos =
- arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
- Value lastPosInBounds =
- generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
- scf::YieldOp::create(builder, loc, lastPosInBounds);
-
- // Populate the "else" region (for size == 0).
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- Value trueVal =
- arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
- scf::YieldOp::create(builder, loc, trueVal);
-
- builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);
-
cf::AssertOp::create(
builder, loc, finalCondition,
generateErrorMessage(
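Illustrative IR for the relaxed check above (a sketch, not taken from the patch): an empty slice may now place its offset exactly on the dimension boundary, while non-empty slices keep the strict bound on both the offset and the last accessed position.

  // Accepted at runtime: size == 0, so offset == dim_size is allowed.
  %empty = tensor.extract_slice %t[10] [0] [1] : tensor<10xf32> to tensor<0xf32>
  // Non-empty: requires 0 <= 8 < 10 and 0 <= 8 + (2 - 1) * 1 < 10.
  %tail = tensor.extract_slice %t[8] [2] [1] : tensor<10xf32> to tensor<2xf32>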
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 293c6af..c420a4c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
}
Attribute newMinValAttr, newMaxValAttr;
@@ -1485,7 +1486,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
return {};
}
+static bool
+mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) {
+ const auto isDynamic = [](Type ty) {
+ const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
+ return !shapedTy || !shapedTy.hasStaticShape();
+ };
+
+ return llvm::any_of(operandTypes, isDynamic) ||
+ failed(verifyCompatibleShapes(operandTypes));
+}
+
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
+ // Select allows operand shapes to be broadcast to the output shape. For
+ // now, don't support folding when we cannot prove no broadcasting is
+ // involved.
+ if (mayRequireBroadcast(getOperandTypes()))
+ return {};
+
if (getOnTrue() == getOnFalse())
return getOnTrue();
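A sketch of what the new guard rules out (example IR assumed, not from the patch): even with identical on_true/on_false operands, the fold is skipped when those operands might be broadcast to the result shape.

  // Not folded: tensor<1xi32> may be broadcast to tensor<4xi32>.
  %0 = tosa.select %cond, %val, %val : (tensor<4xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
  // Still folded to %v: all shapes are static and identical.
  %1 = tosa.select %cond2, %v, %v : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>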
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bf3810f..1c175f9ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
- srcType = quantType.getStorageType();
+ srcType = getStorageElementTypeFromQuantized(quantType);
return srcType;
}
@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
- weightEType = quantType.getStorageType();
+ weightEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
- biasEType = quantType.getStorageType();
+ biasEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
// for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
outputType.getElementType())) {
- if (result.getStorageType() == attrType.getElementType())
+ if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
return success();
}
@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
return success();
}
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
llvm::cast<ShapedType>(getInput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
- inputETy = quantType.getStorageType();
+ inputETy = getStorageElementTypeFromQuantized(quantType);
}
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
- outputETy = quantType.getStorageType();
+ outputETy = getStorageElementTypeFromQuantized(quantType);
}
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
@@ -1761,6 +1761,11 @@ LogicalResult tosa::ConcatOp::verify() {
}
}
+ const ShapeAdaptor outputShape(outType);
+ if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
+ return emitOpError("expect output rank to match inputs rank, got ")
+ << outputShape.getRank() << " vs " << firstInputRank;
+
// ERROR_IF(axis_sum != shape[axis]);
int64_t axisSum = 0;
for (const auto &input : inputList) {
@@ -1772,7 +1777,7 @@ LogicalResult tosa::ConcatOp::verify() {
}
axisSum += inputShape.getDimSize(axis);
}
- const ShapeAdaptor outputShape(outType);
+
if (axisSum >= 0 && outputShape.hasRank() &&
!outputShape.isDynamicDim(axis) &&
axisSum != outputShape.getDimSize(axis))
@@ -2628,7 +2633,7 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
if (!zpElemType.isInteger(8) && zp != 0) {
// convert operand to lower case for error message
std::string lower = operand;
- std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
+ llvm::transform(lower, lower.begin(), ::tolower);
return op.emitOpError()
<< lower << " zero point must be zero for non-int8 integer types";
}
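A hypothetical example of IR that the added concat rank check now rejects (op spelling and attribute types assumed):

  // error: 'tosa.concat' op expect output rank to match inputs rank, got 1 vs 2
  %0 = tosa.concat %a, %b {axis = 0 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<6xf32>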
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 41b338d..091b481 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaAttachTarget.cpp
+ TosaArithConstantToConst.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
@@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
+ TosaNarrowI64ToI32.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
@@ -21,7 +23,9 @@ add_mlir_dialect_library(MLIRTosaTransforms
LINK_LIBS PUBLIC
MLIRFuncDialect
+ MLIRFuncTransformOps
MLIRPass
MLIRTosaDialect
MLIRTransformUtils
+ MLIRFuncTransforms
)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
new file mode 100644
index 0000000..73e1e2b
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
@@ -0,0 +1,111 @@
+//===- TosaArithConstantToConst.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that converts tensor-valued arith.constant ops
+// into tosa.const so that TOSA pipelines operate on a uniform constant form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// NOTE: TOSA pipelines already lower their constants through shared Arith
+// folding passes, so tensor literals often come back as `arith.constant` even
+// after the IR is otherwise TOSA-only. Keep this normalization with the rest of
+// the TOSA transforms so any client can re-establish a canonical `tosa.const`
+// representation without needing a full Arith->TOSA conversion library.
+
+/// Returns true when `elementType` is natively representable by tosa.const.
+static bool isSupportedElementType(Type elementType) {
+ if (isa<FloatType>(elementType))
+ return true;
+
+ if (auto intType = dyn_cast<IntegerType>(elementType))
+ return intType.isSignless() || intType.isUnsigned();
+
+ if (isa<quant::QuantizedType>(elementType))
+ return true;
+
+ if (isa<tosa::mxint8Type>(elementType))
+ return true;
+
+ return false;
+}
+
+class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp constOp,
+ PatternRewriter &rewriter) const override {
+ // TOSA constant verification requires a ranked, statically shaped tensor.
+ auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType());
+ if (!resultType || !resultType.hasStaticShape())
+ return failure();
+
+ if (!isSupportedElementType(resultType.getElementType()))
+ return failure();
+
+ Attribute attr = constOp.getValueAttr();
+ auto elementsAttr = dyn_cast<ElementsAttr>(attr);
+ if (!elementsAttr)
+ return failure();
+
+ auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType());
+ if (!attrType || !attrType.hasStaticShape())
+ return failure();
+ if (attrType != resultType)
+ return failure();
+
+ auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
+ resultType, elementsAttr);
+ rewriter.replaceOp(constOp, newConst.getResult());
+ return success();
+ }
+};
+
+struct TosaArithConstantToTosaConstPass
+ : public tosa::impl::TosaArithConstantToTosaConstPassBase<
+ TosaArithConstantToTosaConstPass> {
+ using Base::Base;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, tosa::TosaDialect>();
+ }
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ patterns.add<ArithConstantToTosaConst>(ctx);
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
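A rough before/after sketch of the rewrite (generic op form shown; the exact spelling of the tosa.const payload attribute is assumed, not taken from the patch):

  // before
  %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
  // after (the constant payload attribute name may differ)
  %cst = "tosa.const"() <{values = dense<[1, 2, 3, 4]> : tensor<4xi32>}> : () -> tensor<4xi32>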
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 0bec0da..022476a2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
ShapedType weightType = cast<ShapedType>(weight.getType());
ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
- if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
- resultType.hasStaticShape())) {
+    // All dims except the batch size must be static for input/output.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightType.hasStaticShape()) {
return failure();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b..8b23fd1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ public:
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
return failure();
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // All dims except the batch size must be static for input/output.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ public:
if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(op, "non-one stride found.");
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // All dims except the batch size must be static for input/output.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t batch = inputTy.getDimSize(0);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
new file mode 100644
index 0000000..ddaf7d8a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
@@ -0,0 +1,310 @@
+//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass narrows TOSA operations with 64-bit integer tensor types to
+// 32-bit integer tensor types. This can be useful for backends that do not
+// support the EXT-INT64 extension of TOSA. The pass has two options:
+//
+// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
+//   regardless of whether the narrowing is safe. This option may lead to
+// data loss if not used carefully.
+// - convert-function-boundaries - If enabled, the pass will convert function
+// I/O types as well. Otherwise casts will be inserted at the I/O
+// boundaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter *typeConverter) {
+ // Convert types of results
+ SmallVector<Type, 4> newResults;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+ return failure();
+
+ // Create a new operation state
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, {}, op->getSuccessors());
+
+ for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+ const Attribute attribute = namedAttribute.getValue();
+
+ // Convert integer attribute type
+ if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+ Type type = typeAttr.getValue();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, attribute);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(op,
+ "Failed to convert type attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+ const Type type = denseElementsAttr.getType();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, denseElementsAttr);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(
+ op, "Failed to convert dense elements attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ state.addAttribute(namedAttribute.getName(), attribute);
+ }
+
+ for (Region &region : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+ return failure();
+ }
+
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+}
+
+// ===========================
+// Aggressive rewrite patterns
+// ===========================
+
+class ConvertGenericOp : public ConversionPattern {
+public:
+ ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!isa<tosa::TosaOp>(op))
+ return rewriter.notifyMatchFailure(
+ op,
+ "Support for operations other than TOSA has not been implemented.");
+
+ return convertGenericOp(op, operands, rewriter, typeConverter);
+ }
+};
+
+// ===============================
+// Bounds checked rewrite patterns
+// ===============================
+
+class ConvertArgMaxOpWithBoundsChecking
+ : public OpConversionPattern<tosa::ArgMaxOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // Output type can be narrowed based on the size of the axis dimension
+ const int32_t axis = op.getAxis();
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ if (!inputType || !inputType.isStaticDim(axis))
+ return rewriter.notifyMatchFailure(
+ op, "Requires a static axis dimension for bounds checking.");
+ const int64_t axisDim = inputType.getDimSize(axis);
+ if (axisDim >= std::numeric_limits<int32_t>::max())
+ return rewriter.notifyMatchFailure(
+ op, "Axis dimension is too large to narrow safely.");
+
+ const Type resultType = op.getOutput().getType();
+ const Type newResultType = typeConverter->convertType(resultType);
+ rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
+ adaptor.getInput(), axis);
+ return success();
+ }
+};
+
+class ConvertCastOpWithBoundsChecking
+ : public OpConversionPattern<tosa::CastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+ if (!inputType || !resultType)
+ return failure();
+
+ const auto elementInputIntType =
+ dyn_cast<IntegerType>(inputType.getElementType());
+ const auto elementResultIntType =
+ dyn_cast<IntegerType>(resultType.getElementType());
+ if (elementInputIntType && elementResultIntType &&
+ elementInputIntType.getWidth() > elementResultIntType.getWidth())
+ return rewriter.notifyMatchFailure(
+ op, "Narrowing cast may lead to data loss.");
+
+ rewriter.replaceOpWithNewOp<tosa::CastOp>(
+ op, typeConverter->convertType(resultType), adaptor.getInput());
+ return success();
+ }
+};
+
+template <typename OpTy>
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ return convertGenericOp(op, adaptor.getOperands(), rewriter,
+ this->getTypeConverter());
+ }
+};
+
+struct TosaNarrowI64ToI32
+ : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
+public:
+ explicit TosaNarrowI64ToI32() = default;
+ explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
+ : TosaNarrowI64ToI32() {
+ this->aggressiveRewrite = options.aggressiveRewrite;
+ this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) -> Type { return type; });
+ typeConverter.addConversion([](IntegerType type) -> Type {
+ if (!type.isInteger(64))
+ return type;
+ return IntegerType::get(type.getContext(), 32);
+ });
+ typeConverter.addConversion(
+ [&typeConverter](RankedTensorType type) -> Type {
+ const Type elementType = type.getElementType();
+ if (!elementType.isInteger(64))
+ return type;
+ return RankedTensorType::get(type.getShape(),
+ typeConverter.convertType(elementType));
+ });
+
+ const auto materializeCast = [](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) -> Value {
+ if (inputs.size() != 1)
+ return Value();
+ return tosa::CastOp::create(builder, loc, resultType, inputs.front());
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ typeConverter.addTypeAttributeConversion(
+ [](IntegerType type, IntegerAttr attribute) -> Attribute {
+ const APInt value = attribute.getValue().truncSSat(32);
+ return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
+ value);
+ });
+ typeConverter.addTypeAttributeConversion(
+ [&typeConverter](ShapedType type,
+ DenseIntElementsAttr attr) -> Attribute {
+ const ShapedType newType =
+ cast<ShapedType>(typeConverter.convertType(type));
+ const auto oldElementType = cast<IntegerType>(type.getElementType());
+ const auto newElementType =
+ cast<IntegerType>(newType.getElementType());
+ if (oldElementType.getWidth() == newElementType.getWidth())
+ return attr;
+
+ DenseElementsAttr mapped =
+ attr.mapValues(newElementType, [&](const APInt &v) {
+ return v.truncSSat(newElementType.getWidth());
+ });
+ return mapped;
+ });
+
+ ConversionTarget target(*context);
+ target.addDynamicallyLegalDialect<tosa::TosaDialect>(
+ [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op->getResultTypes()) &&
+ typeConverter.isLegal(op->getOperandTypes());
+ });
+ if (convertFunctionBoundaries) {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [&typeConverter](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
+ const FunctionType funcType =
+ op->getParentOfType<func::FuncOp>().getFunctionType();
+ return llvm::equal(op.getOperandTypes(), funcType.getResults());
+ });
+ } else {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return true; });
+ target.addDynamicallyLegalOp<func::ReturnOp>(
+ [](func::ReturnOp op) { return true; });
+ }
+
+ RewritePatternSet patterns(context);
+ if (convertFunctionBoundaries) {
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+ }
+ if (aggressiveRewrite) {
+ patterns.add<ConvertGenericOp>(typeConverter, context);
+ } else {
+ // Tensor
+ patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
+ // Data layout
+ patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
+ // Type conversion
+ patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
+      // Control flow
+ patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
+ }
+
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
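A minimal sketch of the default (non-aggressive) behaviour on an assumed example function; with convert-function-boundaries enabled the signature itself would be rewritten to i32 instead of inserting casts:

  // before
  func.func @rev(%arg0: tensor<4xi64>) -> tensor<4xi64> {
    %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi64>) -> tensor<4xi64>
    return %0 : tensor<4xi64>
  }
  // after: the data-layout op is narrowed, casts preserve the i64 boundary
  func.func @rev(%arg0: tensor<4xi64>) -> tensor<4xi64> {
    %0 = tosa.cast %arg0 : (tensor<4xi64>) -> tensor<4xi32>
    %1 = tosa.reverse %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
    %2 = tosa.cast %1 : (tensor<4xi32>) -> tensor<4xi64>
    return %2 : tensor<4xi64>
  }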
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ac5d620..36e8940 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -70,6 +70,8 @@ namespace {
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c].
+// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?].
LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
ArrayRef<int64_t> lowerRankShape,
@@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
higherRankDim = higherRankShape[i + rankDiff];
lowerRankDim = lowerRankShape[i];
- if (lowerRankDim != 1 && higherRankDim != 1 &&
+ auto isStaticDimAndNotEqualToOne = [](int64_t dim) {
+ return dim != 1 && dim != ShapedType::kDynamic;
+ };
+
+ if (isStaticDimAndNotEqualToOne(lowerRankDim) &&
+ isStaticDimAndNotEqualToOne(higherRankDim) &&
lowerRankDim != higherRankDim)
return failure();
@@ -216,22 +223,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
bool mlir::tosa::hasUniqueConstantScatterIndices(
ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
- llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+ const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape();
const unsigned int indicesRank = indicesShape.size();
const unsigned int lastDimSize = indicesShape[indicesRank - 1];
// check each batch of indices from the flat indicesAttr values
// for duplicates
- auto const indicesValues = indicesAttr.getValues<int32_t>();
+ auto const indicesValues = indicesAttr.getValues<APInt>();
assert(
(indicesValues.size() % lastDimSize == 0) &&
"Constant indices data length should be a multiple of indicesShape[-1]");
- std::vector<uint64_t> indices(lastDimSize);
+ std::vector<APInt> indices(lastDimSize);
for (auto beg = indicesValues.begin(); beg < indicesValues.end();
beg += lastDimSize) {
std::copy(beg, beg + lastDimSize, indices.begin());
- std::sort(indices.begin(), indices.end());
+ std::sort(indices.begin(), indices.end(),
+ [](const APInt &a, const APInt &b) { return a.slt(b); });
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
// found duplicate values in indices in batch
return false;
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 02c86a0..c55b13d 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
maxAttr, quantBits, filterQuantDim,
isSigned, narrowRange));
}
+
+Type mlir::tosa::getStorageElementTypeFromQuantized(
+ quant::QuantizedType quantType) {
+ auto quantEty = quantType.getStorageType();
+  // The storage type does not capture the sign information, so explicitly
+  // create an unsigned integer type when the quantized type is unsigned.
+ if (!quantType.isSigned()) {
+ quantEty = IntegerType::get(quantEty.getContext(),
+ quantEty.getIntOrFloatBitWidth(),
+ IntegerType::Unsigned);
+ }
+ return quantEty;
+}
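A small worked example of the distinction (quant type spellings assumed for illustration):

  // !quant.uniform<i8:f32, 0.5> : getStorageType() = i8, helper returns i8
  // !quant.uniform<u8:f32, 0.5> : getStorageType() = i8 (sign dropped),
  //                               helper returns ui8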
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 062606e..86233b0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2062,6 +2062,10 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
DiagnosedSilenceableFailure result = applySequenceBlock(
callee.getBody().front(), getFailurePropagationMode(), state, results);
+
+ if (!result.succeeded())
+ return result;
+
mappings.clear();
detail::prepareValueMappings(
mappings, callee.getBody().front().getTerminator()->getOperands(), state);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 4f4620a..24b0487 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -47,8 +47,6 @@ static bool happensBefore(Operation *a, Operation *b) {
// TransformState
//===----------------------------------------------------------------------===//
-constexpr const Value transform::TransformState::kTopLevelValue;
-
transform::TransformState::TransformState(
Region *region, Operation *payloadRoot,
const RaggedArray<MappedValue> &extraMappings,
@@ -1497,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
- return llvm::to_vector(llvm::map_range(
- range, [](transform::MappedValue value) { return cast<T>(value); }));
+ return llvm::map_to_vector(range, llvm::CastTo<T>);
}
void transform::detail::setApplyToOneResults(
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index f727118..2bd6205 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -156,7 +156,7 @@ DiagnosedSilenceableFailure
transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- std::optional<size_t> selectedRegionIdx;
+ std::optional<int64_t> selectedRegionIdx;
if (auto selectedRegionAttr = getSelectedRegionAttr())
selectedRegionIdx = selectedRegionAttr->getSExtValue();
@@ -232,7 +232,7 @@ LogicalResult transform::tune::AlternativesOp::verify() {
}
if (auto selectedRegionAttr = getSelectedRegionAttr()) {
- size_t regionIdx = selectedRegionAttr->getSExtValue();
+ int64_t regionIdx = selectedRegionAttr->getSExtValue();
if (regionIdx < 0 || regionIdx >= getNumRegions())
return emitOpError()
<< "'selected_region' attribute specifies region at index "
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index a26edac..2986f4c 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -106,14 +106,12 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
AffineMap bound = [&] {
if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
- lowerBound[0] == upperBound[0]) {
+ lowerBound[0] == upperBound[0])
return lowerBound[0];
- }
- if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
+ if (boundType == BoundType::LB && !invalidBound(lowerBound))
return lowerBound[0];
- } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
+ if (boundType == BoundType::UB && !invalidBound(upperBound))
return upperBound[0];
- }
return AffineMap{};
}();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae3423c..2789f63 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::ori:
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::OR, vector);
- // TODO: Add remaining reduction operations.
+ case arith::AtomicRMWKind::minnumf:
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINNUMF, vector);
+ case arith::AtomicRMWKind::maxnumf:
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXNUMF, vector);
+ case arith::AtomicRMWKind::xori:
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::XOR, vector);
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
@@ -6058,19 +6066,21 @@ LogicalResult ScatterOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType valueVType = getVectorType();
- MemRefType memType = getMemRefType();
+ ShapedType baseType = getBaseType();
- if (valueVType.getElementType() != memType.getElementType())
+ if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
+ return emitOpError("requires base to be a memref or ranked tensor type");
+
+ if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getOffsets()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (llvm::size(getOffsets()) != baseType.getRank())
+ return emitOpError("requires ") << baseType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
-
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
@@ -6233,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
LogicalResult ShapeCastOp::verify() {
VectorType sourceType = getSourceVectorType();
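For reference, the newly handled arith.atomicrmw kinds in getVectorReductionOp correspond to reductions of the following form (IR sketch, syntax assumed from the standard vector.reduction format):

  %min = vector.reduction <minnumf>, %vf : vector<8xf32> into f32
  %max = vector.reduction <maxnumf>, %vf : vector<8xf32> into f32
  %xor = vector.reduction <xor>, %vi : vector<8xi32> into i32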
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 546099c..352f477 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
using namespace mlir;
using namespace mlir::bufferization;
@@ -126,6 +127,54 @@ struct TransferWriteOpInterface
}
};
+/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
+/// operates on a memref.
+struct ScatterOpInterface
+ : public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
+ vector::ScatterOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ auto scatterOp = cast<vector::ScatterOp>(op);
+ if (&opOperand != &scatterOp.getBaseMutable())
+ return {};
+ return {{scatterOp.getResult(), BufferRelation::Equivalent}};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
+ auto scatterOp = cast<vector::ScatterOp>(op);
+ assert(isa<TensorType>(scatterOp.getBaseType()) &&
+ "only tensor types expected");
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, scatterOp.getBase(), options, state);
+ if (failed(buffer))
+ return failure();
+ vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
+ /*resultType=*/nullptr, *buffer,
+ scatterOp.getOffsets(), scatterOp.getIndices(),
+ scatterOp.getMask(), scatterOp.getValueToStore());
+ replaceOpWithBufferizedValues(rewriter, op, *buffer);
+ return success();
+ }
+};
+
/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
struct GatherOpInterface
@@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
+ ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index 258f2cb..1af5523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -111,7 +111,7 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
if (!isValidKind(isInt, scanOp.getKind()))
return failure();
- VectorType resType = VectorType::get(destShape, elType);
+ VectorType resType = destType;
Value result = arith::ConstantOp::create(rewriter, loc, resType,
rewriter.getZeroAttr(resType));
int64_t reductionDim = scanOp.getReductionDim();
@@ -121,8 +121,18 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
int64_t initialValueRank = initialValueType.getRank();
SmallVector<int64_t> reductionShape(destShape);
+ SmallVector<bool> reductionScalableDims(destType.getScalableDims());
+
+ if (reductionScalableDims[reductionDim])
+ return rewriter.notifyMatchFailure(
+ scanOp, "Trying to reduce scalable dimension - not yet supported!");
+
+ // The reduction dimension, after reducing, becomes 1. It's a fixed-width
+ // dimension - no need to touch the scalability flag.
reductionShape[reductionDim] = 1;
- VectorType reductionType = VectorType::get(reductionShape, elType);
+ VectorType reductionType =
+ VectorType::get(reductionShape, elType, reductionScalableDims);
+
SmallVector<int64_t> offsets(destRank, 0);
SmallVector<int64_t> strides(destRank, 1);
SmallVector<int64_t> sizes(destShape);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 726da1e..ad16b80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1)
return failure();
+ if (!isa<VectorType>(op->getResult(0).getType()))
+ return failure();
auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
if (!bcastOp)
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae098..462bd8c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,286 @@ private:
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) as:
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+/// Given a create_mask operation:
+///     // mask the first 6x10 elements
+///     %0 = vector.create_mask %c6, %c10 : vector<8x16xi1>
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+/// %false = arith.constant dense<false> : vector<8x16xi1>
+///
+/// Slice [0,0]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [0,8]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,0]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,8]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+ UnrollCreateMaskPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, createMaskOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = createMaskOp.getVectorType();
+ SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+ Location loc = createMaskOp.getLoc();
+
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ VectorType targetVectorType =
+ VectorType::get(*targetShape, rewriter.getI1Type());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+
+ // In each dimension (d), each unrolled vector computes its mask size as:
+ // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalSize, *targetShape)) {
+ SmallVector<Value> unrolledOperands;
+
+ for (auto [i, originalMaskOperand] :
+ llvm::enumerate(createMaskOp.getOperands())) {
+ Value offsetVal =
+  // now, only fold when we can prove that no broadcasting of the operands
+  // is involved.
+ loc, originalMaskOperand, offsetVal);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value unrolledDimSize =
+ arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
+ Value nonNegative =
+ rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
+ Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
+ loc, nonNegative, unrolledDimSize);
+ unrolledOperands.push_back(unrolledOperand);
+ }
+
+ auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
+ loc, targetVectorType, unrolledOperands);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, unrolledMask, result, offsets, strides);
+ }
+ rewriter.replaceOp(createMaskOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+/// Checks whether extractShape is a contiguous slice of shape.
+/// For extractShape to be contiguous in shape:
+///   1) All but the leading dimension of extractShape and shape must match
+///      exactly.
+///   2) The total number of elements in shape must be evenly divisible by the
+///      total number of elements in extractShape.
+/// Examples:
+/// isContiguous([4, 4], [8, 4]) == true
+/// isContiguous([2, 4], [8, 4]) == true
+/// isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+/// isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> extractShape,
+ ArrayRef<int64_t> shape) {
+
+ if (extractShape.size() > shape.size())
+ return false;
+
+ while (!extractShape.empty() && extractShape.front() == 1) {
+ extractShape = extractShape.drop_front();
+ }
+
+ while (!shape.empty() && shape.front() == 1) {
+ shape = shape.drop_front();
+ }
+
+ size_t rankDiff = shape.size() - extractShape.size();
+ if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
+ return false;
+
+ int64_t extractElements = ShapedType::getNumElements(extractShape);
+ int64_t shapeElements = ShapedType::getNumElements(shape);
+ return shapeElements % extractElements == 0;
+}
+
+/// Determines what shape to use with `vector.extract_strided_slice` to extract
+/// a contiguous memory region from a source vector. The extraction must be
+/// contiguous and contain exactly the specified number of elements. If such an
+/// extraction shape cannot be determined, returns std::nullopt.
+/// EXAMPLE 1:
+/// sourceShape = [16], targetElements = 8
+/// Working right-to-left:
+///   - Take min(8, 16) = 8 from the only dim → extractShape = [8],
+/// remaining = 8/8 = 1
+/// Result: [8]
+///
+/// EXAMPLE 2:
+/// sourceShape = [4, 4], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
+/// remaining = 8/4 = 2
+/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+/// remaining = 2/2 = 1
+/// Result: [2, 4]
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+ int64_t targetElements) {
+ SmallVector<int64_t> extractShape;
+ int64_t remainingElements = targetElements;
+
+ // Build extract shape from innermost dimension outward to ensure contiguity.
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+ int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+ extractShape.insert(extractShape.begin(), takeFromDim);
+
+ if (remainingElements % takeFromDim != 0)
+ return std::nullopt; // Not evenly divisible.
+ remainingElements /= takeFromDim;
+ }
+
+ // Fill remaining dimensions with 1.
+ while (extractShape.size() < sourceShape.size())
+ extractShape.insert(extractShape.begin(), 1);
+
+ if (ShapedType::getNumElements(extractShape) != targetElements)
+ return std::nullopt;
+
+ return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position.
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> resultShape) {
+ // Convert result offsets to linear position.
+ int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
+ // Convert linear position to source offsets.
+ return delinearize(linearIndex, computeStrides(sourceShape));
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (and does not decompose to scalars). In these cases, the
+/// unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice.
+///
+/// Example:
+/// Given a shape cast operation:
+/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+ UnrollShapeCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, shapeCastOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType sourceType = shapeCastOp.getSourceVectorType();
+ VectorType resultType = shapeCastOp.getResultVectorType();
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+
+ if (!isContiguous(*targetShape, resultShape))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "Only supports cases where target shape is "
+ "contiguous in result vector shape");
+
+ int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+ // Calculate the shape to extract from source.
+ std::optional<SmallVector<int64_t>> extractShape =
+ calculateSourceExtractShape(sourceShape, targetElements);
+ if (!extractShape)
+ return rewriter.notifyMatchFailure(
+ shapeCastOp,
+ "cannot extract target number of elements contiguously from source");
+
+ Location loc = shapeCastOp.getLoc();
+
+ // Create result vector initialized to zero.
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+
+ VectorType targetType =
+ VectorType::get(*targetShape, sourceType.getElementType());
+
+ SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+ SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ SmallVector<int64_t> sourceOffsets =
+ calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
+ Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+ extractStrides);
+ Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, targetType, sourceChunk);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, targetChunk, result, resultOffsets, insertStrides);
+ }
+
+ rewriter.replaceOp(shapeCastOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
+ UnrollCreateMaskPattern>(patterns.getContext(), options,
+ benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c809c502..c307fb4 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -322,46 +322,61 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
std::optional<Value> padValue,
bool useInBoundsInsteadOfMasking,
ArrayRef<bool> inputScalableVecDims) {
- assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+ VectorType vecToReadTy = VectorType::get(
+ inputVectorSizes, cast<ShapedType>(source.getType()).getElementType(),
+ inputScalableVecDims);
+
+ return createReadOrMaskedRead(builder, loc, source, vecToReadTy, padValue,
+ useInBoundsInsteadOfMasking);
+}
+
+Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
+ Value source,
+ const VectorType &vecToReadTy,
+ std::optional<Value> padValue,
+ bool useInBoundsInsteadOfMasking) {
+ assert(!llvm::is_contained(vecToReadTy.getScalableDims(),
+ ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
- assert(sourceShape.size() == inputVectorSizes.size() &&
+
+ int64_t vecToReadRank = vecToReadTy.getRank();
+ auto vecToReadShape = vecToReadTy.getShape();
+
+ assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
"expected same ranks.");
- auto vectorType =
- VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
- inputScalableVecDims);
assert((!padValue.has_value() ||
padValue.value().getType() == sourceShapedType.getElementType()) &&
"expected same pad element type to match source element type");
- int64_t readRank = inputVectorSizes.size();
+
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
- SmallVector<bool> inBoundsVal(readRank, true);
+ SmallVector<bool> inBoundsVal(vecToReadRank, true);
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the read indices.
- for (unsigned i = 0; i < readRank; i++)
- inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+ for (unsigned i = 0; i < vecToReadRank; i++)
+ inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
ShapedType::isStatic(sourceShape[i]);
}
auto transferReadOp = vector::TransferReadOp::create(
builder, loc,
- /*vectorType=*/vectorType,
+ /*vectorType=*/vecToReadTy,
/*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero),
+ /*indices=*/SmallVector<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
- if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+ if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+ useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
isa<MemRefType>(source.getType())
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
- inputScalableVecDims);
+ auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627..cb1e9d0 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..f4c9f8a
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+ X86VectorTransformOps.cpp
+
+ DEPENDS
+ MLIRX86VectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRX86VectorDialect
+ MLIRX86VectorTransforms
+ )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000..95db208
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,64 @@
+//===- X86VectorTransformOps.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+                            "is only allowed when the result is a VectorType.";
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToFMAPatterns(patterns);
+}
+
+void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ X86VectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ X86VectorTransformDialectExtension)
+
+ X86VectorTransformDialectExtension() {
+ declareGeneratedDialect<x86vector::X86VectorDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
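+/// Registers the X86Vector transform ops with the given registry. A typical
+/// client does this before parsing IR, e.g. (a sketch):
+///   DialectRegistry registry;
+///   x86vector::registerTransformDialectExtension(registry);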
+void mlir::x86vector::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266a..2cab50f 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,11 +1,14 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
+ VectorContractToFMA.cpp
+ VectorContractToPackedTypeDotProduct.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRX86VectorDialect
MLIRIR
+ MLIRLinalgDialect
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 0000000..f3af5ca
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,143 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <16xf32>
+// vector.fma vector<16xf32>
+// ```
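+//
+// A sketch of the full rewrite for the example above (types written out for
+// illustration; the pattern below derives them from the operands):
+// ```
+// %acc1d = vector.shape_cast %acc : vector<1x16xf32> to vector<16xf32>
+// %lhs1d = vector.shape_cast %lhs : vector<1x1xf32> to vector<1xf32>
+// %rhs1d = vector.shape_cast %rhs : vector<1x16xf32> to vector<16xf32>
+// %bcast = vector.broadcast %lhs1d : vector<1xf32> to vector<16xf32>
+// %fma   = vector.fma %bcast, %rhs1d, %acc1d : vector<16xf32>
+// %res   = vector.shape_cast %fma : vector<16xf32> to vector<1x16xf32>
+// ```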
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 lowering is supported.");
+
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return dim != 1; });
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return dim != 1; });
+
+ if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
+ return rewriter.notifyMatchFailure(
+          contractOp, "Expects unit dimensions for either LHS or RHS shape.");
+
+ if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator is not a vector type.");
+
+ if (!accTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator should be F32 type.");
+
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(
+          contractOp, "Expects exactly one non-unit accumulator dimension.");
+
+ // Lowers vector.contract into a broadcast+FMA sequence.
+ auto loc = contractOp.getLoc();
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ vector::FMAOp fma;
+
+    // Broadcast the unit-dimension LHS or RHS to match the vector length of
+    // the corresponding non-unit dimension on the other operand. For example,
+    // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, the
+    // LHS is broadcast to vector<16xf32>. In the opposite case (non-unit
+    // dimension on the LHS), the RHS is broadcast instead.
+ if (nonUnitDimRhs.size() > 0) {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), castLhs);
+ fma =
+ vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+ } else {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto broadcastRhs = vector::BroadcastOp::create(
+ rewriter, loc, castLhs.getResult().getType(), castRhs);
+ fma =
+ vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
+ }
+
+ auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+ rewriter.replaceOp(contractOp, castFma);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractToFMAPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToFMA>(patterns.getContext());
+}
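+
+// A minimal usage sketch outside the transform dialect (assumes `funcOp` is
+// the payload op to rewrite; illustrative only):
+//
+//   RewritePatternSet patterns(funcOp.getContext());
+//   x86vector::populateVectorContractToFMAPatterns(patterns);
+//   (void)applyPatternsGreedily(funcOp, std::move(patterns));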
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
new file mode 100644
index 0000000..1e64811
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -0,0 +1,301 @@
+//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
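+// Infers contraction iterator kinds from the output (accumulator) indexing
+// map: dimensions appearing in the output map are parallel, all remaining
+// dimensions are reductions. For example, for the output map
+// (m, n, k, vnni) -> (m, n), m and n are parallel while k and vnni are
+// reductions.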
+static FailureOr<SmallVector<mlir::utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+ if (!map.isProjectedPermutation())
+ return failure();
+ SmallVector<mlir::utils::IteratorType> iterators(
+ map.getNumDims(), mlir::utils::IteratorType::reduction);
+ for (auto expr : map.getResults())
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
+ iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
+ return iterators;
+}
+
+// Returns true if the operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+ std::optional<unsigned> blockingFactor) {
+  // Narrow down the supported ops - VNNI only applies to contractions.
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return false;
+
+ auto matA = op->getOperand(0);
+ auto matB = op->getOperand(1);
+ auto typeA = dyn_cast<ShapedType>(matA.getType());
+ auto typeB = dyn_cast<ShapedType>(matB.getType());
+ unsigned rankA = typeA.getRank();
+ unsigned rankB = typeB.getRank();
+ // VNNI format requires at least 1 parallel and 2 reduction dimensions.
+ if (rankA < 3 || rankB < 3)
+ return false;
+
+ // At least two reduction dimensions are expected:
+ // one for the VNNI factor and one for the K dimension
+ if (dims->k.size() < 2)
+ return false;
+
+ // Validate affine maps - VNNI computation should be defined by the two
+ // innermost reduction iterators.
+ // The input matrix dimensions layout must match the following:
+ // - matrix A - [...][K/vnniFactor][vnniFactor]
+ // - matrix B - [...][K/vnniFactor][N][vnniFactor]
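+  // For example, a BF16 contraction with vnniFactor = 2 would typically carry
+  // indexing maps of the form:
+  //   A: (m, n, k, vnni) -> (m, k, vnni)
+  //   B: (m, n, k, vnni) -> (k, n, vnni)
+  //   C: (m, n, k, vnni) -> (m, n)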
+ auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+ if (failed(maybeIters))
+ return false;
+ SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] !=
+ mlir::utils::IteratorType::reduction)
+ return false;
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] !=
+ mlir::utils::IteratorType::reduction)
+ return false;
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
+ if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
+ mlir::utils::IteratorType::parallel)
+ return false;
+
+  // The VNNI factor must be:
+  // - the innermost dimension of both inputs
+  // - statically known and non-zero
+  // - a multiple of 2 and, if provided, equal to the blocking factor
+ auto vnniDimSize = typeB.getShape().back();
+ if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
+ vnniDimSize % 2 != 0)
+ return false;
+ if (typeA.getShape().back() != vnniDimSize)
+ return false;
+ if (blockingFactor && vnniDimSize != *blockingFactor)
+ return false;
+
+ // The split reduction dimension size should also match.
+ if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
+ return false;
+
+ return true;
+}
+
+// Implements packed type outer product contraction as a sequence
+// of broadcast and packed dot-product operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+// ```
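+//
+// A sketch of the full rewrite for the BF16 example above (the dot-product
+// line is informal; its exact assembly syntax is defined by the X86Vector
+// dialect):
+// ```
+// %acc1d = vector.shape_cast %acc : vector<1x16xf32> to vector<16xf32>
+// %rhs1d = vector.shape_cast %rhs : vector<1x16x2xbf16> to vector<32xbf16>
+// %lhs1d = vector.shape_cast %lhs : vector<1x1x2xbf16> to vector<2xbf16>
+// %lhsi32 = vector.bitcast %lhs1d : vector<2xbf16> to vector<1xi32>
+// %bcast = vector.broadcast %lhsi32 : vector<1xi32> to vector<16xi32>
+// %lhspk = vector.bitcast %bcast : vector<16xi32> to vector<32xbf16>
+// %dp = x86vector.avx512.dot %acc1d, %lhspk, %rhs1d -> vector<16xf32>
+// %res = vector.shape_cast %dp : vector<16xf32> to vector<1x16xf32>
+// ```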
+struct VectorContractToPackedTypeDotProduct
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16() &&
+ !lhsTy.getElementType().isSignlessInteger(8))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only BF16/Int8 lowering is supported.");
+
+ unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
+ if (!isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(), blockingFactor))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Input matrices not in VNNI format.");
+
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return dim != 1; });
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return dim != 1; });
+
+ if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
+ return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+ "LHS or RHS shape other than VNNI.");
+
+ if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accumulator type.");
+
+ if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+ (lhsTy.getElementType().isSignlessInteger(8) &&
+ !accTy.getElementType().isSignlessInteger(32)))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 for BF16 or Int32 for Int8 "
+ "accumulation type is supported.");
+
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(
+          contractOp, "Expects exactly one non-unit accumulator dimension.");
+
+ // Non-unit dimensions should match the vector length of BF16 or Int8
+ // dot-product.
+ unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
+ : nonUnitDimRhs.front();
+ if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
+ nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
+ return rewriter.notifyMatchFailure(
+          contractOp, "BF16 dot-product operation expects non-unit (LHS or "
+ "RHS) dim and acc dim of size 4/8/16.");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
+ nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+ return rewriter.notifyMatchFailure(
+          contractOp, "Int8 dot-product operation expects non-unit (LHS or "
+ "RHS) dim and acc dim of size 4/8.");
+
+ auto loc = contractOp.getLoc();
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ Value dp;
+
+    // Broadcast the unit-dimension LHS or RHS to match the vector length of
+    // the corresponding non-unit dimension on the other operand. For example,
+    // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
+    // the LHS is broadcast (as packed i32 elements) to vector<32xbf16>. In
+    // the opposite case (non-unit dimension on the LHS), the RHS is broadcast
+    // instead.
+ if ((nonUnitDimRhs.size() - 1) > 0) {
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
+ rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto bitcastLhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castLhs);
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
+ bitcastLhs);
+ auto bitcastLhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+ } else {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
+ lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto bitcastRhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castRhs);
+ auto broadcastRhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
+ bitcastRhs);
+ auto bitcastRhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
+ }
+
+ if (!dp)
+ return failure();
+
+ auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+ rewriter.replaceOp(contractOp, castDp);
+ return success();
+ }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6..46b8251 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 6b4c185..1a19ab5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,10 +8,8 @@
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -38,55 +36,61 @@ void XeGPUDialect::initialize() {
>();
}
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each of size
+// `subShapesLayout` x `subShape` (elementwise). A `delinearizedId` identifies
+// a particular `subShape` within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 in a 4x2 layout. This gives a
+// distribution unit of shape 64x64, and we have 2x4 such distribution units.
+// `delinearizedId` identifies the 16x32 subgroup tile within each
+// distribution unit.
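+// Continuing the example, for delinearizedId = (1, 1):
+//   distUnitShape       = min([128, 256], [4, 2] * [16, 32]) = [64, 64]
+//   distUnitLocalOffset = (1 * 16, 1 * 32) = (16, 32)
+//   dist unit bases     = (0, 0), (0, 64), ..., (64, 192)   (2x4 units)
+//   coordinates         = (16, 32), (16, 96), ..., (80, 224)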
static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
- SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
- ArrayRef<int64_t> sizePerSg,
- ArrayRef<int64_t> sizePerWg) {
-
- SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+ SmallVector<Value> delinearizedId,
+ ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+ ArrayRef<int64_t> srcShape) {
+ SmallVector<SmallVector<Value>> coordinates;
+
+ // A distribution unit must be less than or equal to `srcShape`
+ SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+ llvm::zip_equal(srcShape,
+ computeElementwiseMul(subShapesLayout, subShape)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
- SmallVector<Value> localOffsets = llvm::map_to_vector(
- llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
- return builder.createOrFold<index::MulOp>(
+ // Get the offset of `subShape` within a distribution unit.
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+ llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::MulIOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
- // distUnit[i] is the minimum value between sizePerWg[i] and
- // sgLayout[i] * sizePerSg[i]
- SmallVector<int64_t> distUnit = llvm::map_to_vector(
- llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
- [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+ // For each dist unit
for (SmallVector<int64_t> unitOffs :
- StaticTileOffsetRange(sizePerWg, distUnit)) {
+ StaticTileOffsetRange(srcShape, distUnitShape)) {
+ // Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});
-
- SmallVector<Value> adds = llvm::map_to_vector(
- llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
- return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
- std::get<1>(t));
- });
-
+ // Calculate `subShape` offset within `srcShape`.
+ SmallVector<Value> adds =
+ llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+ [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(
+ loc, std::get<0>(t), std::get<1>(t));
+ });
+ // Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
- llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
- return builder.createOrFold<index::RemUOp>(
+ llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::RemUIOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
- offsets.push_back(mods);
+ coordinates.push_back(mods);
}
- return offsets;
+ return coordinates;
}
// Checks if the given shape can be evenly distributed based on the layout
@@ -273,56 +277,197 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
- // delinearizeSubgroupId is only available for
- // workgroup-level layout attribute
- if (!isForWorkgroup())
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
+
+ SmallVector<int64_t> sgLayoutInt;
+ if (isForWorkgroup()) {
+ sgLayoutInt = getEffectiveSgLayoutAsInt();
+ } else if (isForSubgroup()) {
+ sgLayoutInt = getEffectiveLaneLayoutAsInt();
+ } else {
return failure();
+ }
- // TODO: handle order attribute
- auto hasDefaultOrder = [&]() {
- DenseI32ArrayAttr order = getOrder();
- return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
- llvm::reverse(order.asArrayRef())));
- };
- if (!hasDefaultOrder())
- return mlir::emitError(loc, "order attribute is currently not supported.");
+ DenseI32ArrayAttr orderAttr = getOrder();
- auto dims =
- llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
+ // Handle order attribute
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::to_vector(
+ llvm::map_range(orderAttr.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ } else {
+ // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+ }
- return affine::delinearizeIndex(builder, loc, linearId, dims);
+ if (order.size() != sgLayoutInt.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> result(sgLayoutInt.size());
+ Value remaining = linearId;
+
+  /// Process dimensions in the order they appear in the order array.
+  /// The first dimension in order is the fastest-changing.
+ ///
+ /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+ ///
+ /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
+ /// result=[?,?,?]
+ ///
+ /// i=0 (process columns, dimIdx=2, dimSize=4):
+ /// result[2] = 22 % 4 = 2 (column coordinate)
+ /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
+ ///
+ /// i=1 (process rows, dimIdx=1, dimSize=4):
+ /// result[1] = 5 % 4 = 1 (row coordinate)
+ /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
+ ///
+ /// i=2 (process layers, dimIdx=0, dimSize=2):
+ /// result[0] = 1 % 2 = 1 (layer coordinate)
+ /// (no remaining update - last iteration)
+ ///
+ /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ int64_t dimSize = sgLayoutInt[dimIdx];
+
+ Value dimSizeVal =
+ builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+    /// Extract the coordinate for this dimension using a modulo operation.
+    /// This gives us "how far within this dimension" we are,
+    /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
+    /// this dimension).
+ result[dimIdx] =
+ builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
+
+    /// Update remaining for the next dimension by removing what we've already
+    /// processed. Division tells us how many complete groups of this
+    /// dimension we've gone through, e.g., linearId=22, dimSize=4:
+    /// 22 / 4 = 5 (we've completed 5 groups of 4). Skip this for the last
+    /// iteration since there's no next dimension to process.
+ if (i < order.size() - 1) {
+ remaining =
+ builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
+ }
+ }
+ return result;
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
- if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
return failure();
-
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ }
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
- SmallVector<Value> sgIds = *maybeIds;
+ SmallVector<Value> ids = *maybeIds;
+
+ return genCoordinates(builder, loc, ids, layout, subShape, shape);
+}
+
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::SliceAttr>(other))
+ return false;
+
+ return *this == dyn_cast<xegpu::LayoutAttr>(other);
+}
+
+// Set sg_data, inst_data, and lane_data to 1 for the specified unit dims.
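+// For example, with sg_data = [16, 32], inst_data = [8, 16],
+// lane_data = [1, 2], and unitDims = {1}, the layouts are kept and the data
+// fields become sg_data = [16, 1], inst_data = [8, 1], lane_data = [1, 1].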
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ auto sgDataOpt = getSgData();
+ auto instDataOpt = getInstData();
+ auto laneDataOpt = getLaneData();
+
+ SmallVector<int32_t> sgData;
+ SmallVector<int32_t> instData;
+ SmallVector<int32_t> laneData;
+
+ if (sgDataOpt) {
+ sgData = llvm::to_vector(sgDataOpt.asArrayRef());
+ }
+ if (instDataOpt) {
+ instData = llvm::to_vector(instDataOpt.asArrayRef());
+ }
+ if (laneDataOpt) {
+ laneData = llvm::to_vector(laneDataOpt.asArrayRef());
+ }
+
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgData.size()))
+ sgData[dim] = 1;
+ if (dim < static_cast<int64_t>(instData.size()))
+ instData[dim] = 1;
+ if (dim < static_cast<int64_t>(laneData.size()))
+ laneData[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(), getSgLayout(),
+ sgData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgData),
+ instData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), instData),
+ getLaneLayout(),
+ laneData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneData),
+ getOrder());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ auto sgLayoutOpt = getSgLayout();
+ auto laneLayoutOpt = getLaneLayout();
+
+ SmallVector<int32_t> sgLayout;
+ SmallVector<int32_t> laneLayout;
+
+ if (sgLayoutOpt) {
+ sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
+ }
+ if (laneLayoutOpt) {
+ laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
+ }
+
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgLayout.size()))
+ sgLayout[dim] = 1;
+ if (dim < static_cast<int64_t>(laneLayout.size()))
+ laneLayout[dim] = 1;
+ }
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return LayoutAttr::get(
+ getContext(),
+ sgLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgLayout),
+ getSgData(), getInstData(),
+ laneLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneLayout),
+ getLaneData(), getOrder());
}
//===----------------------------------------------------------------------===//
@@ -376,34 +521,43 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeSubgroupId(builder, loc, linearId);
+ return parent.delinearizeId(builder, loc, linearId);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// instructions for computing multi-dimensional offsets when distributed by
+// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
@@ -413,8 +567,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
@@ -437,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::LayoutAttr>(other))
+ return false;
+
+ auto flattenedThis = flatten();
+ auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+
+ return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+ (flattenedThis.getDims() == flattenedOther.getDims()));
+}
+
+// Maps unit dimensions from the sliced space to the parent space.
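+// For example, with sliceDims = [0] and unitDims = {0, 1} (in sliced space),
+// the parent rank is 3, the non-sliced dims are [1, 2], and the adjusted
+// unit dims are {1, 2}.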
+static SetVector<int64_t>
+adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
+ ArrayRef<int64_t> sliceDims) {
+ // Reconstruct parent's non-sliced dimensions
+
+ int64_t parentRank = sliceDims.size() + unitDims.size();
+ llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
+ sliceDims.end());
+ SmallVector<int64_t> nonSlicedDims;
+ for (int64_t i = 0; i < parentRank; ++i) {
+ if (!slicedDimsSet.contains(i))
+ nonSlicedDims.push_back(i);
+ }
+
+ // Map unit dims from sliced space to parent space
+ SetVector<int64_t> adjustUnitDims;
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
+ adjustUnitDims.insert(nonSlicedDims[dim]);
+ }
+ }
+
+ return adjustUnitDims;
+}
+
+// Set sg_data, inst_data, and lane_data to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
+ attr.getDims());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
+ attr.getDims());
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2..91ba07a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
- UnitAttr subgroup_block_io,
+ UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
+ "are only allowed when result is a VectorType.";
else
return success();
}
@@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ ArrayAttr strideAttr = mdescTy.getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ if (subgroup_block_io && layout) {
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (!laneData.empty()) {
+ bool isLaneDataContiguous =
+ std::all_of(laneData.begin(), std::prev(laneData.end()),
+ [](int x) { return x == 1; });
+ if (!isLaneDataContiguous)
+ return emitError() << "With subgroup_block_io, accessed data must be "
+ "contiguous and coalesced.";
+ for (size_t i = 0; i < laneData.size(); ++i) {
+ if (laneLayout[i] != blockShape[i])
+ return emitError() << "With subgroup_block_io, the block shape must "
+ "match the lane layout.";
+ if (laneLayout[i] != 1 && strides[i] != 1)
+ return emitError() << "With subgroup_block_io, the distributed "
+ "dimensions must be contiguous.";
+ }
+ }
+ }
if (dataShape.size() == 2) {
- if (subgroup_block_io)
- return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
} else {
- SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
if (subgroup_block_io && !blockShape.size())
@@ -258,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for the fully static case;
+    // otherwise the printer would fail trying to print an empty array attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -320,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for the fully static case;
+    // otherwise the printer would fail trying to print an empty array attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -439,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -454,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult PrefetchNdOp::verify() {
@@ -472,11 +499,8 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -496,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
return build(builder, state, retType, tensorDesc, ValueRange(),
DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
@@ -504,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -512,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- packed, transpose, l1_hint, l2_hint, l3_hint);
+ packed, transpose, l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/layout);
}
LogicalResult LoadNdOp::verify() {
@@ -597,11 +623,8 @@ LogicalResult LoadNdOp::verify() {
<< tdescTy;
int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -618,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, value, tensorDesc, ValueRange(),
- DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+ DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/nullptr);
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -633,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult StoreNdOp::verify() {
@@ -691,11 +716,8 @@ LogicalResult StoreNdOp::verify() {
<< dstTy;
int64_t tDescRank = dstTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ int64_t offsetSize = getMixedOffsets().size();
+ if (offsetSize != 0 && offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
@@ -809,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
- IntegerAttr{});
+ IntegerAttr{}, /*anchor_layout=*/nullptr);
}
//===----------------------------------------------------------------------===//
@@ -859,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +897,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint,
+ DistributeLayoutAttr layout) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint, layout);
}
//===----------------------------------------------------------------------===//
@@ -926,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +983,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
+}
+
+void StoreScatterOp::build(
+ OpBuilder &builder, OperationState &state, Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint, layout);
}
//===----------------------------------------------------------------------===//
@@ -1105,7 +1160,7 @@ LogicalResult LoadMatrixOp::verify() {
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1129,7 +1184,7 @@ LogicalResult StoreMatrixOp::verify() {
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
namespace mlir {
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..48fe841
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+ XeGPUTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+ DEPENDS
+ MLIRXeGPUTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRXeGPUDialect
+ MLIRXeGPUTransforms
+ MLIRIR
+ MLIRTransformDialect
+ MLIRFuncDialect
+ MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000..e6009d5
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,695 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+
+#include <optional>
+
+#include "llvm/Support/DebugLog.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
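+/// For example, each entry may be an IntegerAttr (`8 : index`), a transform
+/// param carrying a single IntegerAttr, or a handle to a single constant-like
+/// op with one index result (e.g. `arith.constant 8 : index`).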
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+ transform::TransformState &state, TransformOpInterface transformOp,
+ SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+ for (OpFoldResult ofr : ofrs) {
+ // Attribute case.
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ result.push_back(intAttr.getInt());
+ continue;
+ }
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ }
+
+ // Transform param case.
+ Value transformValue = cast<Value>(ofr);
+ if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ result.push_back(
+ cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+ continue;
+ }
+
+ // Payload value case.
+ auto payloadOps = state.getPayloadOps(transformValue);
+ if (!llvm::hasSingleElement(payloadOps)) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "handle must be mapped to exactly one payload op";
+ diag.attachNote(transformValue.getLoc())
+ << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+ return diag;
+ }
+
+ Operation *op = *payloadOps.begin();
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+
+ IntegerAttr intAttr;
+ if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+ return transformOp.emitSilenceableError()
+ << "requires param or handle to be the result of a constant like "
+ "op";
+
+ result.push_back(intAttr.getInt());
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Find producer operation of type T for the given value.
+/// It's assumed that producer ops are chained through their first operand.
+/// The producer chain is traced through loop block arguments (init values).
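+/// For example (a sketch of a traced chain):
+///   %desc = xegpu.create_nd_tdesc %src ... : !xegpu.tensor_desc<...>
+///   %r = scf.for ... iter_args(%arg = %desc) -> (...) {
+///     // findProducerOfType<xegpu::CreateNdDescOp>(%arg) resolves to %desc.
+///     ...
+///   }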
+template <typename T>
+static std::optional<T> findProducerOfType(Value val) {
+ Value currentValue = val;
+ if (!currentValue.getDefiningOp()) {
+ // Value may be a block argument initialized outside a loop.
+ if (val.getNumUses() == 0) {
+ LDBG() << "Failed to find producer op, value has no uses.";
+ return std::nullopt;
+ }
+ auto userOp = val.getUsers().begin();
+ auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
+ if (!parentLoop) {
+ LDBG() << "Failed to find producer op, not in a loop.";
+ return std::nullopt;
+ }
+ int64_t iterArgIdx;
+ if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
+ auto numInductionVars = parentLoop.getLoopInductionVars()->size();
+ iterArgIdx = iterArg.getArgNumber() - numInductionVars;
+ currentValue = parentLoop.getInits()[iterArgIdx];
+ } else {
+ LDBG() << "Failed to find producer op, value not in init values.";
+ return std::nullopt;
+ }
+ }
+ Operation *producerOp = currentValue.getDefiningOp();
+
+ if (auto matchingOp = dyn_cast<T>(producerOp))
+ return matchingOp;
+
+ if (producerOp->getNumOperands() == 0)
+ return std::nullopt;
+
+ return findProducerOfType<T>(producerOp->getOperand(0));
+}
+
+/// Create a layout attribute from the given parameters.
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+ ArrayRef<int32_t> sgData,
+ std::optional<ArrayRef<int32_t>> instData) {
+ return xegpu::LayoutAttr::get(
+ ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+ DenseI32ArrayAttr::get(ctx, sgData),
+ instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+ /*lane_layout=*/nullptr,
+ /*lane_data=*/nullptr,
+ /*order=*/nullptr);
+}
+
+/// Generate an `xegpu::LayoutAttr` from the op's mixed layout operand values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
+ TransformOpInterface transformOp,
+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ xegpu::LayoutAttr &layoutAttr) {
+ SmallVector<int32_t> sgLayout, sgData, instData;
+ auto status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+ xegpu::CreateNdDescOp descOp,
+ xegpu::DistributeLayoutAttr layout) {
+ assert(descOp.getMixedOffsets().size() == 0 &&
+ "create desc op with offsets is not supported");
+ auto oldTensorDesc = descOp.getType();
+ auto descType = xegpu::TensorDescType::get(
+ oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
+ /*array_length=*/oldTensorDesc.getArrayLength(),
+ /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
+ /*memory_space=*/oldTensorDesc.getMemorySpace(),
+ /*layout=*/layout);
+
+ rewriter.setInsertionPointAfter(descOp);
+ auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+ descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+ descOp.getMixedStrides());
+ return newDescOp;
+}
+
+DiagnosedSilenceableFailure
+transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues)) {
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ }
+
+ auto maybeDescOp =
+ findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
+ if (!maybeDescOp) {
+ return emitSilenceableFailure(getLoc())
+ << "Could not find a matching descriptor op when walking the "
+ "producer chain of the first operand.";
+ }
+
+ results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+ OperationState &result, Value target,
+ ArrayRef<OpFoldResult> mixedSgLayout,
+ ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData,
+ ArrayRef<int64_t> sliceDims) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, result, target.getType(),
+ /*target=*/target,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
+ // For now only create_nd_desc op is supported.
+ auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+ if (!descOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a xegpu.create_nd_desc op, but got: "
+ << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Set layout attr in desc op's return type. Replaces old desc op.
+ auto newdescOp = setDescLayout(rewriter, descOp, layout);
+
+ // Map result handles.
+ results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::SetOpLayoutAttrOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
+ bool result) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*index=*/index,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims,
+ /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ bool resultTarget = getResult();
+
+ int64_t index = getIndex();
+ if (resultTarget && index >= target->getNumResults()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op results";
+ }
+ if (!resultTarget && index >= target->getNumOperands()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op operands";
+ }
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
+ // Set layout attribute for the op result or operand
+ if (resultTarget)
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
+ else
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::SetGPULaunchThreadsOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedThreads) {
+ SmallVector<int64_t> staticThreads;
+ SmallVector<Value> dynamicThreads;
+ dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*threads=*/dynamicThreads,
+ /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+ if (!launchOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a gpu.launch op, but got: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ SmallVector<int32_t> threads;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+ if (!status.succeeded())
+ return status;
+
+ if (threads.size() != 3) {
+ return emitSilenceableFailure(getLoc())
+ << "Expected threads argument to consist of three values (got "
+ << threads.size() << ")";
+ }
+
+ rewriter.setInsertionPoint(launchOp);
+ auto createConstValue = [&](int value) {
+ return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+ };
+
+ // Replace threads in-place.
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+ launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+ launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getThreadsMutable(), effects);
+ modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ int64_t nbPrefetch = getStaticNbPrefetch();
+ if (getDynamicNbPrefetch()) {
+ // Get dynamic prefetch count from transform param or handle.
+ SmallVector<int32_t> dynamicNbPrefetch;
+ auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+ {getDynamicNbPrefetch()});
+ if (!status.succeeded())
+ return status;
+ if (dynamicNbPrefetch.size() != 1)
+ return emitDefiniteFailure()
+ << "requires exactly one value for dynamic_nb_prefetch";
+ nbPrefetch = dynamicNbPrefetch[0];
+ }
+ if (nbPrefetch <= 0)
+ return emitSilenceableFailure(getLoc())
+ << "nb_prefetch must be a positive integer.";
+
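+  // Rough sketch of the rewrite, assuming nb_prefetch = 2 (illustrative
+  // pseudo-IR, offsets elided):
+  //   %desc = xegpu.create_nd_tdesc ...                   // cloned above loop
+  //   scf.for %i = %lb to (%lb + 2 * %step) step %step {  // unrolled below
+  //     xegpu.prefetch_nd %desc[..., %i]
+  //   }
+  //   scf.for %i = %lb to %ub step %step {
+  //     xegpu.prefetch_nd %desc[..., %i + 2 * %step]
+  //     ... original loads ...
+  //   }
+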
+ // Find load operation of the operand.
+ auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+ if (!maybeLoadOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+ auto loadOp = *maybeLoadOp;
+ if (loadOp.getMixedOffsets().size() == 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op must have offsets.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find the parent scf.for loop.
+ auto forOp = loadOp->getParentOfType<scf::ForOp>();
+ if (!forOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op is not contained in a scf.for loop.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find descriptor op.
+ auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+ if (!maybeDescOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ auto descOp = *maybeDescOp;
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+ // Clone desc op outside the loop.
+ rewriter.setInsertionPoint(forOp);
+ auto newDescOp =
+ cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Create an init loop that emits the initial prefetches.
+  // Its upper bound is: start + nbPrefetch * step.
+ auto nbPrefetchCst =
+ arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+ auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+ auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+ forOp.getLoc(), forOp.getLowerBound(), nbStep);
+ auto initForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ initUpBound, forOp.getStep());
+
+ auto ctx = rewriter.getContext();
+ auto readCacheHint =
+ xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+ // Modify loadOp mixedOffsets by replacing the for loop induction variable
+ // with the given value.
+ auto getPrefetchOffsets =
+ [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+ IRMapping mapping;
+ mapping.map(forOp.getInductionVar(), replacementVal);
+ SmallVector<Value> dynamicOffsets =
+ llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+ return mapping.lookupOrDefault(v);
+ }));
+ auto constOffsets = loadOp.getConstOffsets().value();
+ return getMixedValues(constOffsets, dynamicOffsets, ctx);
+ };
+
+ // Insert prefetch op in init loop.
+ // Replace induction var with the init loop induction var.
+ rewriter.setInsertionPointToStart(initForOp.getBody());
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(initForOp.getInductionVar()),
+ readCacheHint, readCacheHint, readCacheHint,
+ /*layout=*/nullptr);
+
+ // Insert prefetch op in main loop.
+ // Calculate prefetch offset after the init prefetches have been issued.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+ forOp.getInductionVar(), nbStep);
+ // Replace induction var with correct offset.
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(prefetchOffset), readCacheHint,
+ readCacheHint, readCacheHint, /*layout=*/nullptr);
+
+ // Unroll the init loop.
+ if (failed(loopUnrollFull(initForOp)))
+ return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+
+ results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::ConvertLayoutOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedInputSgLayout,
+ ArrayRef<OpFoldResult> mixedInputSgData,
+ ArrayRef<OpFoldResult> mixedInputInstData,
+ ArrayRef<OpFoldResult> mixedTargetSgLayout,
+ ArrayRef<OpFoldResult> mixedTargetSgData,
+ ArrayRef<OpFoldResult> mixedTargetInstData) {
+ SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
+ staticInputInstData;
+ SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
+ dynamicInputInstData;
+ dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
+ staticInputSgLayout);
+ dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
+ staticInputSgData);
+ dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
+ staticInputInstData);
+ SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
+ staticTargetInstData;
+ SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
+ dynamicTargetInstData;
+ dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
+ staticTargetSgLayout);
+ dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
+ staticTargetSgData);
+ dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
+ staticTargetInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*input_sg_layout=*/dynamicInputSgLayout,
+ /*input_sg_data=*/dynamicInputSgData,
+ /*input_inst_data=*/dynamicInputInstData,
+ /*target_sg_layout=*/dynamicTargetSgLayout,
+ /*target_sg_data=*/dynamicTargetSgData,
+ /*target_inst_data=*/dynamicTargetInstData,
+ /*static_input_sg_layout=*/staticInputSgLayout,
+ /*static_input_sg_data=*/staticInputSgData,
+ /*static_input_inst_data=*/staticInputInstData,
+ /*static_target_sg_layout=*/staticTargetSgLayout,
+ /*static_target_sg_data=*/staticTargetSgData,
+ /*static_target_inst_data=*/staticTargetInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ // Construct layout attributes.
+ xegpu::LayoutAttr inputLayoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedInputSgLayout(),
+ getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::LayoutAttr targetLayoutAttr = nullptr;
+ status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedTargetSgLayout(),
+ getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ // Find first user op to define insertion point for layout conversion.
+ if (value.use_empty())
+ return emitSilenceableFailure(getLoc())
+ << "Value has no users to insert layout conversion.";
+ Operation *userOp = *value.getUsers().begin();
+
+ // Emit convert_layout op.
+ rewriter.setInsertionPoint(userOp);
+ auto convLayoutOp =
+ xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
+ value, inputLayoutAttr, targetLayoutAttr);
+  // Replace uses of the original value with the layout-converted result.
+ rewriter.replaceUsesWithIf(
+ value, convLayoutOp.getResult(), [&](OpOperand &use) {
+ return use.getOwner() != convLayoutOp.getOperation();
+ });
+
+ results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ConvertLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getInputSgLayoutMutable(), effects);
+ onlyReadsHandle(getInputSgDataMutable(), effects);
+ onlyReadsHandle(getInputInstDataMutable(), effects);
+ onlyReadsHandle(getTargetSgLayoutMutable(), effects);
+ onlyReadsHandle(getTargetSgDataMutable(), effects);
+ onlyReadsHandle(getTargetInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+namespace {
+class XeGPUTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ XeGPUTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+ using Base::Base;
+
+ void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+ declareGeneratedDialect<scf::SCFDialect>();
+ declareGeneratedDialect<arith::ArithDialect>();
+ declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+ >();
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+ registry.addExtensions<XeGPUTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f7606..29b645f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUWgToSgDistribute.cpp
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
+ XeGPUOptimizeBlockLoads.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
new file mode 100644
index 0000000..ab41fe4
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -0,0 +1,490 @@
+//===- XeGPUOptimizeBlockLoads.cpp - XeGPU optimize block loads -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-block-loads"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+/// Get the 2D lane data from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ if (laneData.size() != 2)
+ return std::nullopt;
+ return laneData;
+}
+
+/// Get the 2D lane layout from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (laneLayout.size() != 2)
+ return std::nullopt;
+ return laneLayout;
+}
+
+/// A layout can be optimized if its lane layout is transposed
+/// (lane_layout[0] != 1 && lane_layout[1] == 1) but its lane data is not
+/// [1, 1].
+/// Example:
+/// !xegpu.tensor_desc<16x16xf16,
+///    #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+/// In this case, the lane layout is transposed (from the usual [1, SG_SIZE]
+/// form), indicating a load that requires a transpose. However, lane data is
+/// [1, 2], meaning that each lane must grab 2 f16 elements from the inner
+/// dimension. We convert this to an optimized form by converting the
+/// tensor_desc to i32 type such that lane data becomes [1, 1]. This lets the
+/// later lowering use the load-with-transpose instruction directly.
+static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
+ ArrayRef<int64_t> laneData) {
+ if (laneLayout.size() != 2 || laneData.size() != 2)
+ return false;
+ if (laneLayout[0] == 1 || laneLayout[1] != 1)
+ return false;
+ if (laneData[0] != 1 || laneData[1] == 1)
+ return false;
+ return true;
+}
+
+/// A tensor desc type can be optimized if its element type is less than 32 bits
+/// and its layout can be optimized.
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
+  // Element types of 32 bits or wider do not need this optimization.
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+ if (elementTyBitwidth >= 32)
+ return false;
+ auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+ auto maybeLaneData = getMaybeLaneData(tdescType);
+ if (!maybeLaneData || !maybeLaneLayout)
+ return false;
+ return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
+}
+
+/// Check if a tensor desc type can be optimized for transpose; if so, return a
+/// new optimized tensor desc type with a valid transpose layout. Otherwise,
+/// return the original type unchanged.
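+/// For example (illustrative, assuming the hardware supports the resulting
+/// shape): a 16x16xf16 tensor_desc with lane_data = [1, 2] is rewritten to a
+/// 16x8xi32 tensor_desc with lane_data = [1, 1]; two f16 values are packed
+/// into one i32 element and the inner dimension is halved accordingly.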
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
+ const uArch *targetuArch) {
+ if (!canBeOptimizedForTranspose(tdescType))
+ return tdescType;
+ auto laneData = getMaybeLaneData(tdescType)
+ .value(); // Lane data must exist if we reach here.
+ int64_t innerLaneData = laneData[1];
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // The required shape is the total shape of the vector result that this
+  // tensor desc must eventually load, after adjusting for the new bitwidth and
+  // array length.
+ SmallVector<int64_t> requiredShape(tdescType.getShape());
+ requiredShape.back() =
+ requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+ int newBitWidth = elementTyBitwidth * innerLaneData;
+ Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // The supported shape is the largest hardware-supported transpose shape that
+  // is less than or equal to the required shape.
+ auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
+ targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
+ auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
+ newElemTy, /** has transform */ false, /** has transpose */ true);
+ // If no HW params found, return the original type.
+ if (!maybeHWParams)
+ return tdescType;
+ auto [widths, heights, counts] = maybeHWParams.value();
+ // TODO: Currently we expect array length to be 1 for transpose case.
+ if (counts.size() != 1 || counts[0] != 1)
+ return tdescType;
+ int arrayLen = counts[0];
+ int supportedHeight =
+ xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
+ int supportedWidth =
+ xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
+ // If no supported height or width found, return the original type.
+ if (supportedHeight == -1 || supportedWidth == -1)
+ return tdescType;
+
+ SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+ tdescType.getContext(),
+ tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+  // Array length cannot be larger than 1 for the transpose case.
+ return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
+ tdescType.getBoundaryCheck(),
+ tdescType.getMemorySpace(), newLayout);
+}
+
+/// Helper to convert an OpFoldResult to Value.
+static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
+ OpFoldResult ofr) {
+ std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
+ if (mayBeInt)
+ return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult();
+ return llvm::cast<Value>(ofr);
+}
+
+/// Helper to divide a Value by a constant integer.
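+/// For example, dividing by 4 is emitted as a right shift by 2, since 4 is a
+/// power of two; otherwise an unsigned divide is emitted.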
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
+ Value val, int64_t constant) {
+ // If the constant is a power of 2, use right shift for division.
+ if (llvm::isPowerOf2_64(constant)) {
+ int64_t shiftAmount = llvm::Log2_64(constant);
+ return arith::ShRUIOp::create(
+ rewriter, loc, val,
+ arith::ConstantIndexOp::create(rewriter, loc, shiftAmount)
+ .getResult())
+ .getResult();
+ }
+ auto constantOp =
+ arith::ConstantIndexOp::create(rewriter, loc, constant).getResult();
+ return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+/// This function takes a larger register block `data` and generates multiple
+/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block
+/// starting from `offsets`.
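+/// For example (illustrative): if `data` is vector<32x16xi32> and
+/// `newTensorDesc` has shape 16x16, the 2x1 shape ratio produces two
+/// xegpu.load_nd ops whose results are placed into `data` at row offsets 0 and
+/// 16 via vector.insert_strided_slice.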
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+ TypedValue<VectorType> data,
+ SmallVector<OpFoldResult> offsets,
+ TypedValue<xegpu::TensorDescType> newTensorDesc,
+ xegpu::LoadNdOp origLoadOp) {
+ Location loc = data.getLoc();
+ assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+ Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+ Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+ SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape());
+  // Compute the ratio between the original shape and the supported shape;
+  // loads are generated in this grid arrangement.
+  auto shapeRatio = computeShapeRatio(data.getType().getShape(),
+                                      supportedShape)
+                        .value(); // `shapeRatio` must have a value here.
+ for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+ for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+ int64_t localOffsetDim0 = h * supportedShape[0];
+ int64_t localOffsetDim1 = w * supportedShape[1];
+ Value loadOffsetX = arith::AddIOp::create(
+ rewriter, loc, offsetDim0,
+ arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0)
+ .getResult());
+ Value loadOffsetY = arith::AddIOp::create(
+ rewriter, loc, offsetDim1,
+ arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1)
+ .getResult());
+ auto loadOp = xegpu::LoadNdOp::create(
+ rewriter, loc,
+ VectorType::get(supportedShape, data.getType().getElementType()),
+ newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+ origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+ origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+ origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
+ // Set the layout for the loadOp.
+ auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+ xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
+ // Insert the loaded block into the right position in data.
+ auto insertOp = vector::InsertStridedSliceOp::create(
+ rewriter, loc, loadOp.getResult(), data,
+ ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
+ ArrayRef<int64_t>{1, 1});
+ // InsertOp must have the same layout as newTensorDesc.
+ xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+ data = insertOp.getResult();
+ }
+ }
+ return data;
+}
+
+/// Checks if a CreateNdDescOp can be optimized for transpose; if so, creates a
+/// new CreateNdDescOp with the optimized tensor desc type. This involves
+/// extracting the base pointer from the original memory source and adjusting
+/// the shape and strides of the tensor desc to fit the new transpose layout.
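+/// Illustrative sketch (pseudo-IR, assuming f16 data packed into i32): a
+/// create_nd_tdesc producing tensor_desc<16x16xf16, lane_data = [1, 2]> over a
+/// row-major memref<64x64xf16> becomes a create_nd_tdesc over the extracted
+/// base pointer with shape [64, 32] and strides [32, 1], producing
+/// tensor_desc<16x8xi32, lane_data = [1, 1]>.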
+class XeGPUCreateNdDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto tdescTy = createNdOp.getType();
+ // Get the target uArch info.
+ auto chipStr = xegpu::getChipStr(createNdOp);
+ // Check if the chip is supported.
+ assert(
+ chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
+ "Expecting target chip to be pvc or bmg for transpose optimization.");
+ const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());
+
+ auto convertType = tryOptimize(tdescTy, targetuArch);
+ if (convertType == tdescTy)
+ return failure();
+ auto strides = createNdOp.getMixedStrides();
+ auto maybeConstInnerStride = getConstantIntValue(strides.back());
+ // Only row-major memrefs are expected for now.
+ if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+ return rewriter.notifyMatchFailure(
+ createNdOp, "Expecting row-major memref for transpose optimization.");
+ Value source = createNdOp.getSource();
+ auto optionalLaneData = getMaybeLaneData(tdescTy);
+ assert(optionalLaneData && "Expected 2D lane data");
+ auto laneData = optionalLaneData.value();
+ int64_t innerLaneData = laneData[1];
+ auto memrefType = dyn_cast<MemRefType>(source.getType());
+ // Inner dimension of the shape must be adjusted based on innerLaneData.
+ SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+ modifiedShape.back() = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+ innerLaneData);
+ // Similarly, second to last stride must be adjusted.
+ assert(strides.size() >= 2 &&
+ "Expected at least 2 strides for CreateNdDescOp");
+ SmallVector<OpFoldResult> modifiedStrides(strides);
+ modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(),
+ modifiedStrides[modifiedStrides.size() - 2]),
+ innerLaneData);
+
+    // If the source is a static memref, we need to extract the pointer to the
+    // base address.
+ if (memrefType && memrefType.hasStaticShape()) {
+ auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, createNdOp.getLoc(), source);
+ source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
+ rewriter.getI64Type(),
+ extractOp.getResult())
+ .getResult();
+ }
+ // Create a new CreateNdDescOp with the modified shape and converted type.
+ auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+ rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+ modifiedStrides);
+ rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
+ return success();
+ }
+};
+
+/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
+/// transpose optimization. If so, rewrites the LoadNdOp to align with the
+/// adjusted tensor desc type. This can result in multiple LoadNdOps being
+/// generated to fill in the original load shape.
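+/// Illustrative sketch (pseudo-IR): a load_nd of vector<16x16xf16> from the
+/// converted tensor_desc<16x8xi32> becomes one or more smaller i32 loads
+/// assembled with vector.insert_strided_slice and then bitcast back to
+/// vector<16x16xf16>, with offsets along the inner dimension divided by the
+/// packing factor.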
+class XeGPULoadNdDescOpPattern final
+ : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto origTensorDescType = loadNdOp.getTensorDescType();
+ auto adaptorType =
+ cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
+ if (adaptorType == origTensorDescType)
+ return failure();
+ // Offsets must be adjusted based on innerLaneData.
+ auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+ int64_t innerLaneData = laneData[1];
+ auto offsets = loadNdOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(loadNdOp,
+ "Expecting offsets in LoadNd");
+ SmallVector<OpFoldResult> modifiedOffsets(offsets);
+ modifiedOffsets.back() = divideByConstant(
+ rewriter, loadNdOp.getLoc(),
+ convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
+ innerLaneData);
+ // Get the 2D data shape of this loadNdOp in its original type including
+ // array length.
+ SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+ // Adjust the data shape based on innerLaneData.
+ origDataShape.back() /= innerLaneData;
+ // HW supported shape is the new tensor desc shape after conversion.
+ SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+ VectorType origVectorType =
+ VectorType::get(origDataShape, adaptorType.getElementType());
+ Value data;
+    // With array length > 1, the original load result covers multiple 2D
+    // slices; handle each array slice separately.
+ if (origTensorDescType.getArrayLength() > 1) {
+ SmallVector<Value> arraySlices;
+ for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+ Value slice = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(), origVectorType,
+ rewriter.getZeroAttr(origVectorType));
+ // Increase the Y offset for each array slice.
+ Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
+ modifiedOffsets.back());
+ modifiedOffsets.back() =
+ arith::AddIOp::create(
+ rewriter, loadNdOp->getLoc(), offsetY,
+ arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
+ i * origDataShape[1])
+ .getResult())
+ .getResult();
+ slice = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ // BitCast back to original load shape without array length.
+ auto bitcastType = VectorType::get(origTensorDescType.getShape(),
+ origTensorDescType.getElementType());
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ bitcastType, slice);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ arraySlices.push_back(bitCastOp.getResult());
+ }
+ rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
+ return success();
+ }
+ data = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(),
+ VectorType::get(origDataShape, adaptorType.getElementType()),
+ rewriter.getZeroAttr(origVectorType));
+ data = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ loadNdOp.getType(), data);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ rewriter.replaceOp(loadNdOp, bitCastOp);
+ return success();
+ }
+};
+
+/// A vector ExtractOp must be processed if the original tensor desc type has
+/// an array length greater than 1. In that case, the LoadNdOp is replaced with
+/// one LoadNdOp per array slice, making the extraction unnecessary; the
+/// ExtractOp is simply replaced with the corresponding slice.
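+/// Illustrative sketch: `vector.extract %load[1]`, where `%load` was split
+/// into per-slice values, simply forwards the second slice to the extract's
+/// users.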
+class VectorExtractOpPattern final
+ : public OpConversionPattern<vector::ExtractOp> {
+public:
+ using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+    // Check if the source of the extraction was split into multiple values.
+ if (adaptor.getSource().size() == 1)
+ return failure();
+ auto mixedPos = extractOp.getMixedPosition();
+ if (mixedPos.size() != 1)
+ return failure();
+ auto mayBeInt = getConstantIntValue(mixedPos[0]);
+ if (!mayBeInt)
+ return failure();
+ rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
+ return success();
+ }
+};
+
+} // namespace
+
+void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
+ VectorExtractOpPattern>(patterns.getContext());
+}
+
+namespace {
+
+struct XeGPUOptimizeBlockLoadsPass final
+ : public xegpu::impl::XeGPUOptimizeBlockLoadsBase<
+ XeGPUOptimizeBlockLoadsPass> {
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+
+    // This pass is only meant for PVC and BMG targets. If no supported target
+    // is found, exit early.
+ bool isTargetSupported = false;
+ getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
+ auto chipStr = xegpu::getChipStr(funcOp);
+ if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
+ isTargetSupported = true;
+ });
+
+ if (!isTargetSupported) {
+ DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets."
+ << "\n";
+ return;
+ }
+
+ // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be
+ // converted.
+ target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+ [&](xegpu::CreateNdDescOp createNdOp) {
+ return !canBeOptimizedForTranspose(createNdOp.getType());
+ });
+ target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
+ [&](xegpu::LoadNdOp loadNdOp) {
+ return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
+ });
+ // Vector ExtractOps can have optimizable layouts if they extract from
+ // LoadNdOps with array length greater than 1. These ExtractOps must be
+ // converted.
+ target.addDynamicallyLegalOp<vector::ExtractOp>(
+ [&](vector::ExtractOp extractOp) {
+ auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+ if (!layout)
+ return true;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ return !canBeOptimizedForTranspose(laneLayout, laneData);
+ });
+ converter.addConversion([](Type type) { return type; });
+
+ target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
+ vector::VectorDialect>();
+ scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
+ xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ DBGS() << "Optimize block loads pass failed.\n";
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 90eae87..dc9eb96 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,6 +53,8 @@ using namespace mlir::dataflow;
namespace {
+enum class LayoutKind { Lane, InstData };
+
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -166,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
llvm_unreachable("Join should not be triggered by layout propagation.");
}
-/// Construct a new layout with the transposed lane layout and lane data.
+/// Construct a new layout with the inst_data and/or lane_layout and lane_data
+/// transposed.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
if (!isAssigned())
return {};
@@ -186,12 +189,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneData;
SmallVector<int32_t> instData;
for (int64_t idx : permutation) {
- laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
- laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
- instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+ if (getLaneLayout().size()) {
+ laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+ laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ }
+ if (getInstData().size())
+ instData.push_back(static_cast<int32_t>(getInstData()[idx]));
}
- return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
- laneLayout, laneData));
+ xegpu::LayoutAttr layoutAttr;
+ if (getLaneLayout().size())
+ layoutAttr =
+ xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+ if (getInstData().size())
+ layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
+ return LayoutInfo(layoutAttr);
}
//===----------------------------------------------------------------------===//
@@ -204,28 +215,6 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
using Lattice::Lattice;
};
-/// Helper Function to find a proper instruction multiple for the user-supplied
-/// sg-level data shape. `candidates` are uArch allowed shapes.
-/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
-template <typename T>
-int getLargestDivisor(T dim, ArrayRef<T> candidates,
- ArrayRef<T> candidateMultiples = {}) {
- static_assert(std::is_integral<T>::value, "T must be an integer type");
- int largest = -1;
- SmallVector<T> multiples = {1};
- if (!candidateMultiples.empty())
- multiples =
- SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
- for (T candidate : candidates) {
- for (T multiple : multiples) {
- int value = static_cast<int>(candidate * multiple);
- if (value != 0 && dim % value == 0 && value > largest)
- largest = value;
- }
- }
- return largest;
-}
-
/// Helper Functions to get default layouts. A `default layout` is a layout that
/// is assigned to a value when the layout is not fixed by some anchor operation
/// (like DPAS).
@@ -235,15 +224,14 @@ int getLargestDivisor(T dim, ArrayRef<T> candidates,
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
unsigned rank,
- const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData) {
+ const xegpu::uArch::uArch *uArch) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
return LayoutInfo(
- xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
+ xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(
- ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
@@ -258,7 +246,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData,
unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
@@ -269,16 +256,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{uArch->getSubgroupSize(), 1},
{1, packingFactor}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{1, uArch->getSubgroupSize()},
{1, packingFactor}));
}
@@ -286,7 +273,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData,
unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
@@ -297,18 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
+ return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
int subgroupSize = uArch->getSubgroupSize();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));
+ tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
}
return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));
+ tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -320,7 +306,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData, unsigned packingSize) {
+ unsigned packingSize) {
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
@@ -332,10 +318,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
{static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
1});
return LayoutInfo(
- xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
+ xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
+ return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -350,6 +336,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
+ LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -400,10 +387,14 @@ private:
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
+
public:
LayoutInfoPropagation(DataFlowSolver &solver,
- SymbolTableCollection &symbolTable)
- : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+ SymbolTableCollection &symbolTable,
+ LayoutKind layoutKind)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+ layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -486,43 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}
+bool LayoutInfoPropagation::hasParamsOfLayoutKind(
+ xegpu::DistributeLayoutAttr anchorLayout) {
+ if (anchorLayout == nullptr) {
+ return false;
+ }
+ if (layoutKind == LayoutKind::InstData) {
+ return !(anchorLayout.getEffectiveInstDataAsInt().empty());
+ } else if (layoutKind == LayoutKind::Lane) {
+ return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ anchorLayout.getEffectiveLaneDataAsInt().empty());
+ }
+ return false;
+}
+
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Here we assign the default layout to the tensor descriptor operand of
- // prefetch.
- auto tdescTy = prefetch.getTensorDescType();
-
- auto uArch = getUArch(getChipStr(prefetch).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
-
- auto blockWHC =
- uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
- if (!blockWHC)
- prefetch.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
- bCount);
- if (instWidth == -1)
- prefetch.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (tdescTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+
+ LayoutInfo prefetchLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ prefetchLayout = LayoutInfo(anchorLayout);
+ } else {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+
+ auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+ auto blockWHC =
+ uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+ if (!blockWHC)
+ prefetch.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
+ if (instWidth == -1)
prefetch.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
+ if (tdescTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ prefetchLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+ else
+ prefetchLayout = getDefaultSIMTLayoutInfo(
+ tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+
+ prefetch.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
}
- auto prefetchLayout = getDefaultSIMTLayoutInfo(
- tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
@@ -561,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!sourceTy) {
- broadcast.emitWarning("Expecting source type to be a vector type.");
+  // Skip layout propagation for a non-vector source operand.
+ if (!sourceTy)
return;
- }
- // Only consider nD -> nD broadcast.
+  // Handle the broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
if (sourceTy.getRank() != resultTy.getRank()) {
- broadcast.emitWarning("Expecting source and result to have same rank.");
+ auto sourceDims = sourceTy.getShape();
+ auto resultDims = resultTy.getShape();
+ SmallVector<int64_t> bcastDims;
+ auto dimDiff = resultTy.getRank() - sourceTy.getRank();
+    // Add the missing leading dims.
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+    // For the remaining dims in resultTy, if the sourceTy dim is 1 and the
+    // result dim is not, it is a broadcasted dim.
+ for (size_t i = 0; i < sourceDims.size(); i++)
+ if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
+ bcastDims.push_back(i + dimDiff);
+
+    // Create a slice layout for the source.
+ xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+ broadcast->getContext(),
+ cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+ DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
+
SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
- if (broadcastUnitDims.size() != 1) {
- broadcast.emitWarning("Expecting source type to be nD vector only with "
- "one broadcasted dimension.");
- return;
- }
- // Propagate the result layout to the source operand.
+ resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
+ .setUnitDimData(broadcastUnitDims);
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
@@ -622,55 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- VectorType aTy = dpas.getLhsType();
- VectorType bTy = dpas.getRhsType();
-
- auto uArch = getUArch(getChipStr(dpas).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
- xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
-
- const unsigned dataALen = aTy.getShape().front();
- auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
- const int maxALen =
- getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
- if (maxALen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
-
- const unsigned dataBLen = bTy.getShape().back();
- auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
- const int maxBLen =
- getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
- if (maxBLen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataA = {maxALen, subgroupSize};
- SmallVector<int> instDataB = {subgroupSize, maxBLen};
-
- propagateIfChanged(operands[0],
- operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
- aTy, 0, uArch, instDataA,
- uArchInstruction->getPackedFormatBitSizeA())));
- propagateIfChanged(operands[1],
- operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
- bTy, 1, uArch, instDataB,
- uArchInstruction->getPackedFormatBitSizeB())));
- if (operands.size() > 2) {
- VectorType cTy = dpas.getAccType();
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxCLen =
- getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1)
+
+ LayoutInfo dpasALayout;
+ LayoutInfo dpasBLayout;
+ LayoutInfo dpasCDLayout;
+
+ xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
+ if (hasParamsOfLayoutKind(anchorLayoutCD)) {
+ xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
+ xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
+ assert(hasParamsOfLayoutKind(anchorLayoutA) &&
+ "Expected anchor layout for DPAS A operand.");
+ assert(hasParamsOfLayoutKind(anchorLayoutB) &&
+ "Expected anchor layout for DPAS B operand.");
+ dpasALayout = LayoutInfo(anchorLayoutA);
+ dpasBLayout = LayoutInfo(anchorLayoutB);
+ dpasCDLayout = LayoutInfo(anchorLayoutCD);
+
+ } else {
+
+ VectorType aTy = dpas.getLhsType();
+ VectorType bTy = dpas.getRhsType();
+
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ const unsigned dataALen = aTy.getShape().front();
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+ if (maxALen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataC = {maxALen, maxCLen};
- propagateIfChanged(operands[2],
- operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
- cTy, 2, uArch, instDataC,
- uArchInstruction->getPackedFormatBitSizeB())));
+
+ const unsigned dataBLen = bTy.getShape().back();
+ auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+
+ const int maxBLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+ if (maxBLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataA = {maxALen, subgroupSize};
+ SmallVector<int> instDataB = {subgroupSize, maxBLen};
+
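+    // E.g., assuming an f16 DPAS shape of M = 8, N = 16 and subgroupSize = 16
+    // (a PVC-like target, used here only for illustration), an 8x16 A tile
+    // and a 16x16 B tile give instDataA = {8, 16} and instDataB = {16, 16}.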
+ if (layoutKind == LayoutKind::InstData) {
+ dpasALayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+ dpasBLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ } else {
+ dpasALayout = getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+ dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
+
+ if (operands.size() > 2) {
+ VectorType cTy = dpas.getAccType();
+ if (layoutKind == LayoutKind::InstData) {
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxCLen = xegpu::getLargestDivisor(
+ dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataC = {maxALen, maxCLen};
+ dpasCDLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+ } else
+ dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+ dpas.setLayoutCdAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+ }
+ dpas.setLayoutAAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
+ dpas.setLayoutBAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+ }
+
+ propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+ propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
+ if (operands.size() > 2) {
+ propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
}
@@ -679,37 +756,50 @@ void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- auto uArch = getUArch(getChipStr(store).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
- VectorType dataTy = store.getValueType();
- auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
- store.getValueType().getElementType());
- if (!blockWHC)
- store.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = getLargestDivisor(
- static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
- bCount);
- if (instWidth == -1)
- store.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (dataTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = getLargestDivisor(
- static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+ LayoutInfo storeLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ storeLayout = LayoutInfo(anchorLayout);
+ } else {
+ auto uArch = getUArch(getChipStr(store).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
+ VectorType dataTy = store.getValueType();
+ auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
+ store.getValueType().getElementType());
+ if (!blockWHC)
+ store.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
+ if (instWidth == -1)
store.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
+ if (dataTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ storeLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+ else
+ storeLayout =
+ getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
+ store.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
}
- LayoutInfo storeLayout =
- getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
- uArchInstruction->getPackedFormatBitSize());
+ // Propagate the layout to the value operand.
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -720,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp(
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo valueLayout = results[0]->getValue();
- // Need the layout of the value to propagate to the tensor descriptor.
- if (!valueLayout.isAssigned())
- return;
- LayoutInfo tensorDescLayout = valueLayout;
- // LoadNdOp has the transpose effect. However, at the stage of this analysis
- // this effect is not expected and should be abstracted away. Emit a
- // warning.
- if (auto transpose = load.getTranspose()) {
- load.emitWarning("Transpose effect is not expected for LoadNdOp at "
- "LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.transpose(transpose.value());
+
+ LayoutInfo loadLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ } else {
+
+ LayoutInfo valueLayout = results[0]->getValue();
+ // Need the layout of the value to propagate to the tensor descriptor.
+ if (!valueLayout.isAssigned())
+ return;
+ loadLayout = valueLayout;
+ // LoadNdOp has the transpose effect. However, at the stage of this analysis
+ // this effect is not expected and should be abstracted away. Emit a
+ // warning.
+ if (auto transpose = load.getTranspose()) {
+ load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+ "LayoutInfoPropagation stage.");
+ loadLayout = valueLayout.transpose(transpose.value());
+ }
+ load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
// Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
/// For vector::TransposeOp, the layout of the result is transposed and
@@ -824,33 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // The layout is strictly determined by the payload type.
- auto payloadTy = dyn_cast<VectorType>(load.getValueType());
- if (!payloadTy) {
- load.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- auto uArch = getUArch(getChipStr(load).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto srcTdescTy =
- dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
- if (srcTdescTy.getChunkSizeAsInt() > 1)
+
+ LayoutInfo loadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ maskLayout = loadLayout;
+ } else {
+
+ // The layout is strictly determined by the payload type.
+ VectorType payloadTy = load.getValueType();
+ if (!payloadTy) {
+ load.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+ auto uArch = getUArch(getChipStr(load).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
instData.push_back(chunkSize);
- }
- LayoutInfo layout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
- /*scattered*/ true);
+ else if (auto srcTdescTy =
+ dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
+ if (srcTdescTy.getChunkSizeAsInt() > 1)
+        instData.push_back(srcTdescTy.getChunkSizeAsInt());
+ }
- // Mask operand should have 1D default layout.
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
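+    // E.g., with subgroupSize = 16 and chunkSize = 8, instData = {16, 8}.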
+ if (layoutKind == LayoutKind::InstData)
+ loadLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+ else
+ loadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered*/ true);
+ // Mask operand should have 1D default layout.
+ maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+
+ load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+ }
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
- propagateIfChanged(operands[0], operands[0]->meet(layout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets())
@@ -878,38 +992,56 @@ void LayoutInfoPropagation::visitCreateDescOp(
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Currently, for 2D StoreScatterOp we expect that the height dimension of
- // the tensor descriptor is equal to the subgroup size. This is ensured by
- // the op verifier.
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
- if (!payloadTy) {
- storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
-
- auto payloadShape = payloadTy.getShape();
- if (payloadShape.size() > 1)
- assert(
- payloadShape[0] == subgroupSize &&
- "Expected the first dimension of 2D tensor descriptor to be equal to "
- "subgroup size.");
-
- SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto dstTdescTy =
- dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
- if (dstTdescTy.getChunkSizeAsInt() > 1)
- instData.push_back(chunkSize);
- }
- LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
- /*scattered=*/true);
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+ LayoutInfo payloadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ payloadLayout = LayoutInfo(anchorLayout);
+ maskLayout = payloadLayout;
+ } else {
+ // Currently, for 2D StoreScatterOp we expect that the height dimension of
+ // the tensor descriptor is equal to the subgroup size. This is ensured by
+ // the op verifier.
+ VectorType payloadTy = storeScatter.getValueType();
+ if (!payloadTy) {
+ storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+
+ if (layoutKind == LayoutKind::InstData) {
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+ chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+ storeScatter.getDestType())) {
+ if (dstTdescTy.getChunkSizeAsInt() > 1)
+          instData.push_back(dstTdescTy.getChunkSizeAsInt());
+ }
+ payloadLayout = LayoutInfo(
+ xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+ } else {
+ auto payloadShape = payloadTy.getShape();
+ if (payloadShape.size() > 1)
+ assert(payloadShape[0] == subgroupSize &&
+               "Expected the first dimension of 2D tensor descriptor to be "
+               "equal to subgroup size.");
+ payloadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered=*/true);
+ }
+
+ maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+
+ storeScatter.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
+ }
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
@@ -931,10 +1063,10 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op) : target(op) {
+ RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
(void)solver.initializeAndRun(op);
}
@@ -1041,7 +1173,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setDistributeLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
}
return success();
}
@@ -1174,7 +1306,18 @@ struct XeGPUPropagateLayoutPass final
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
- auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+ LayoutKind layoutKind;
+ if (this->layoutKind == "lane") {
+ layoutKind = LayoutKind::Lane;
+ } else if (this->layoutKind == "inst") {
+ layoutKind = LayoutKind::InstData;
+ } else {
+ getOperation()->emitError("Unsupported layout kind option: " +
+ this->layoutKind);
+ signalPassFailure();
+ return;
+ }
+ RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1188,8 +1331,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return {};
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
- if (this->layoutKind == "lane")
- layoutAttr = layoutAttr.dropInstData();
if (layout.isSliceLayout())
return cast<xegpu::SliceAttr>(layoutAttr);
return cast<xegpu::LayoutAttr>(layoutAttr);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 5a3b27e..ca81c3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -98,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
if (i < distributionStart)
continue;
-
// Check if the dimension can be distributed evenly.
if (dim % effectiveLaneLayout[i - distributionStart] != 0)
return failure();
@@ -173,6 +173,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
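+/// E.g., an original type vector<8x32xf32> distributed to vector<8x1xf32>
+/// yields distributed dims {1}.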
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+ VectorType distributedType) {
+ assert(originalType.getRank() == distributedType.getRank() &&
+ "sequential and distributed vector types must have the same rank");
+ SmallVector<int64_t> distributedDims;
+ for (int64_t i = 0; i < originalType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+ distributedDims.push_back(i);
+ }
+ }
+ return distributedDims;
+}
+
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
@@ -912,6 +927,183 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
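+
+/// Computes the per-lane coordinates for a matrix load/store op by adding the
+/// layout-derived distributed offsets to the original offsets. Returns an
+/// empty vector if the layout cannot produce distributed coordinates.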
+static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
+ PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+ Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newCoords;
+ auto maybeCoords =
+ layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+ if (failed(maybeCoords))
+ return {};
+ assert(maybeCoords.value().size() == 1 &&
+ "Expected one set of distributed offsets");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+ getAsOpFoldResult(origOffsets));
+  newCoords = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
+  return newCoords;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
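+/// The memory descriptor and offsets are yielded from the warp op, the
+/// payload type is distributed according to the layout attribute, and (unless
+/// subgroup_block_io is set) the offsets are adjusted by the per-lane
+/// distributed coordinates before a new load is created outside the warp op.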
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ const int operandIdx = producedByLastLoad->getOperandNumber();
+
+ VectorType sgPayloadTy =
+ dyn_cast<VectorType>(matrixOp.getResult().getType());
+ VectorType warpResultTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the load op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ return success();
+ }
+};
+
+/// Pattern for distributing xegpu::StoreMatrixOp.
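+/// Mirrors LoadMatrixDistribution: the payload, memory descriptor, and
+/// offsets are yielded from the warp op and a distributed store is created
+/// outside it.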
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the store op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
+ return success();
+ }
+};
+
/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
@@ -1231,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
}
};
+/// This pattern distributes the `vector.broadcast` operation across lanes in a
+/// warp. The pattern supports three use cases:
+///
+/// 1) Broadcast a low-rank vector to a high-rank vector: The low-rank input
+///    vector must have a slice layout of the result. If the distributed
+///    source and target vector types are identical, this lowers to a no-op;
+///    otherwise, it remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: The source vector must have unit dimensions, and lane_data must
+///    be unit size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
+///    from the scalar to the distributed result type.
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [0]> } : () -> (vector<32xf32>)
+/// %2 = vector.broadcast %0 {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+/// : vector<32xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [0]> } : () -> (vector<32xf32>)
+/// gpu.yield %0 : vector<32xf32>
+/// }
+/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+/// ```
+///
+/// Example 2 (no-op):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [1]> } : () -> (vector<8xf32>)
+/// %1 = vector.shape_cast %0
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8xf32> to vector<8x1xf32>
+/// %2 = vector.broadcast %1
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [1]> } : () -> (vector<8xf32>)
+/// %1 = vector.shape_cast %0
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8xf32> to vector<8x1xf32>
+/// gpu.yield %1 : vector<8x1xf32>
+/// }
+/// // The broadcast is implicit through layout transformation (no-op)
+/// "some_use"(%r#0)
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+ if (!yieldOperand)
+ return failure();
+ auto broadcastOp =
+ cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandIdx = yieldOperand->getOperandNumber();
+
+ VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ VectorType destType =
+ dyn_cast<VectorType>(broadcastOp.getResult().getType());
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+
+ FailureOr<VectorType> sourceDistType;
+ Type sourceElemOrDistType;
+ if (sourceType) {
+
+ // Case 1 and 2: source is a vector type.
+ int64_t rankDiff = destType.getRank() - sourceType.getRank();
+ if (rankDiff > 0) {
+ // Case 1: source is lower-rank than result.
+ bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
+ if (!isSliceOf)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Broadcast input layout must be a slice of result layout.");
+ }
+ // case 2: source and result have same rank
+ if (rankDiff == 0) {
+ SetVector<int64_t> broadcastUnitDims =
+ broadcastOp.computeBroadcastedUnitDims();
+ resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+ bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
+ if (!isEqualTo)
+ return rewriter.notifyMatchFailure(
+              warpOp, "For same-rank broadcast, the source layout must match "
+                      "the result layout adjusted for unit dims.");
+ sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
+ }
+
+ sourceDistType =
+ getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+ if (failed(sourceDistType)) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the source vector type.");
+ }
+ sourceElemOrDistType = sourceDistType.value();
+
+ } else {
+ // Case 3: source is a scalar type.
+ if (sourceLayout) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Broadcast from scalar must not have a layout attribute.");
+ }
+ sourceElemOrDistType = broadcastOp.getSourceType();
+ }
+ FailureOr<VectorType> destDistType =
+ getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+ if (failed(destDistType)) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the dest vector type.");
+ }
+
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
+ newRetIndices);
+
+ Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+
+ Value newBroadcast = distributedSource;
+
+ if (sourceElemOrDistType != destDistType.value()) {
+ rewriter.setInsertionPointAfter(newWarpOp);
+ newBroadcast =
+ vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
+ destDistType.value(), distributedSource);
+ }
+
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
+ return success();
+ }
+};
+
/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
@@ -1291,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+// Distribute a `vector.extract_strided_slice` op feeding into yield op of an
+// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+// advanced cases where the distributed dimension is partially extracted and
+// currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ auto extractOp =
+ cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+ unsigned operandIdx = operand->getOperandNumber();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Find the distributed dimensions.
+ auto extractResultType = cast<VectorType>(operand->get().getType());
+ auto distributedDims =
+ getDistributedDims(extractResultType, distributedType);
+ // Collect updated source type, sizes and offsets. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ VectorType updatedSourceType = extractOp.getSourceVectorType();
+ SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+ extractOp.getSizes(), [](Attribute attr) { return attr; });
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ extractOp.getOffsets(), [](Attribute attr) { return attr; });
+ // If the result is distributed, it must be distributed in exactly one
+ // dimension. In this case, we adjust the sourceDistType, distributedSizes
+ // and distributedOffsets accordingly.
+ if (distributedDims.size() > 0) {
+ if (distributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+            warpOp, "Source cannot be distributed in multiple dimensions.");
+ int64_t distributedDim = distributedDims[0];
+ int sourceDistrDimSize =
+ extractOp.getSourceVectorType().getShape()[distributedDim];
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source of extract_strided_slice op lacks distribution "
+ "layout");
+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize = sourceLaneLayout[distributedDim];
+ // Check if the source size in the distributed dimension is a multiple of
+ // subgroup size.
+ if (sourceDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Source size along distributed dimension is not a multiple of "
+ "subgroup size.");
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ // We expect lane data to be all ones in this case.
+ if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source layout");
+      // The offset in the distributed dimension must be a multiple of the
+      // subgroup size.
+ int64_t distrDimOffset =
+ cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+ if (distrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Offset along distributed dimension "
+ "is not a multiple of subgroup size.");
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, extractOp.getSourceVectorType())
+ .value();
+ // Update the distributed sizes to match the distributed type.
+ updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+ distributedType.getDimSize(distributedDim));
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
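+      // E.g., with subgroupSize = 16, unit lane data, and a vector<64xf32>
+      // source: each lane owns elements {laneId, laneId + 16, laneId + 32,
+      // laneId + 48}, so extracting 16 elements at offset 32 becomes a
+      // 1-element extract at per-lane offset 32 / 16 = 2.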
+ updatedOffsets[distributedDim] =
+ rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source of the extract op from
+ // the warp op and creating a new extract op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ // Create a new extract op outside the warp op.
+ Value newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), distributedType, source,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ ArrayAttr::get(rewriter.getContext(), updatedSizes),
+ extractOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+ return success();
+ }
+};
+
+/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases where the distributed dimension is partially inserted and
+/// currently not supported by the generic vector distribution patterns.
+struct VectorInsertStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Find the distributed dimensions of the dest vector.
+ auto insertResultType = cast<VectorType>(operand->get().getType());
+ auto destDistributedDims =
+ getDistributedDims(insertResultType, distributedType);
+ // Collect updated offsets, source type and dest type. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ insertOp.getOffsets(), [](Attribute attr) { return attr; });
+ VectorType updatedSourceType = insertOp.getSourceVectorType();
+ VectorType updatedDestType = insertOp.getDestVectorType();
+ if (destDistributedDims.size() > 0) {
+ // Only single dimension distribution is supported.
+ if (destDistributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+            "Expecting the dest to be distributed in a single dimension.");
+ int64_t destDistributedDim = destDistributedDims[0];
+
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp,
+ "distributed dimension must be in the last k (i.e. source "
+ "rank) dims of dest vector");
+ int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+ // Obtain the source and dest layouts.
+ auto destLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+ if (!destLayout || !sourceLayout ||
+ destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source or dest of insert_strided_slice op lacks "
+ "distribution layout");
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize =
+ destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+ // We require that source and dest lane data are all ones to ensure
+ // uniform round robin distribution.
+ auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+ !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source and dest layouts");
+      // Source distributed dim size must be a multiple of subgroup size.
+ if (srcDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Distributed dimension size in source is not a multiple of "
+ "subgroup size.");
+      // The offset in the distributed dimension must be a multiple of the
+      // subgroup size.
+ int64_t destDistrDimOffset =
+ cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+ if (destDistrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Offset along distributed dimension in dest is not a multiple of "
+ "subgroup size.");
+ // Update the source and dest types based on their layouts.
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, insertOp.getSourceVectorType())
+ .value();
+ updatedDestType = getDistVecTypeBasedOnLaneLayout(
+ destLayout, insertOp.getDestVectorType())
+ .value();
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
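+      // E.g., with subgroupSize = 16 and a dest offset of 48 along the
+      // distributed dim, each lane inserts its distributed slice at per-lane
+      // offset 48 / 16 = 3.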
+ updatedOffsets[destDistributedDim] =
+ rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source and dest of the insert op
+ // from the warp op and creating a new insert op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {updatedSourceType, updatedDestType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+ Value dest = newWarpOp.getResult(newRetIndices[1]);
+ // Create a new insert op outside the warp op.
+ Value newInsertOp = vector::InsertStridedSliceOp::create(
+ rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ insertOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+ newInsertOp);
+ return success();
+ }
+};
+
/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
/// outside of the warp op.
@@ -1443,13 +2015,18 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
GpuBarrierDistribution, VectorMultiReductionDistribution,
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution,
+ VectorBitcastDistribution, LoadMatrixDistribution,
+ StoreMatrixDistribution,
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
- patterns.add<VectorShapeCastDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/highPatternBenefit);
+  // The following patterns need to override the regular vector distribution
+  // patterns, so assign them a higher benefit.
+ patterns
+ .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+ VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/highPatternBenefit);
}
void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
@@ -1468,6 +2045,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+ continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index e6e71cc..af63f09 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
// return dummy Value to satisfy function's signature
return nullptr;
};
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return xegpu::LoadNdOp::create(
rewriter, loc, newValueTy, convertedTdescs[0], offsets,
op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
};
newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
*targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
convertedTdescs[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
// return dummy Value to satisfy function's signature
return nullptr;
};
@@ -678,12 +687,16 @@ struct UnrollLoadGatherOpWithOffset
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}
+ auto layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
+
SmallVector<Value> newOps;
for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
auto newOp = xegpu::LoadGatherOp::create(
rewriter, loc, newValueTy, op.getSource(), o, m,
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
newOps.push_back(newOp);
}
@@ -774,12 +787,16 @@ struct UnrollStoreScatterOpWithOffsets
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ auto layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
+
for (auto [v, o, m] :
llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
rewriter.getI64IntegerAttr(chunkSize),
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9..be82cda 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
if (origOffsets.empty())
return failure();
+  // LoadMatrixOp and StoreMatrixOp carry the layout attribute directly; other
+  // ops (e.g., xegpu::CreateNdDescOp) expose it via getDescLayoutAttr().
+ xegpu::DistributeLayoutAttr layout;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
+ std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
+ layout = op.getLayoutAttr();
+ } else {
+ layout = op.getDescLayoutAttr();
+ }
+
// not applicable to ops without workgroup layout attributes
- xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -114,7 +122,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeDescOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeDescOffsets))
return failure();
@@ -189,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
xegpu::TensorDescType tdescTy = op.getType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
Type elemTy = tdescTy.getElementType();
- xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
auto newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
@@ -308,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
SmallVector<Value> newOps;
for (auto [tdesc, offsets] :
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -317,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
auto newOp = xegpu::LoadNdOp::create(
rewriter, op.getLoc(), newResTy, tdesc, offsets,
/*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
newOps.push_back(newOp);
}
rewriter.replaceOpWithMultiple(op, {newOps});
@@ -338,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
for (auto [v, tdesc, offsets] :
llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
@@ -362,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
for (auto [tdesc, offsets] :
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
@@ -488,10 +506,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -737,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Location loc = op.getLoc();
auto eltType = vecType.getElementType();
- auto setLayoutIfNeeded = [&](Value val) {
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
- layout.dropSgLayoutAndData());
- }
+ auto setLayout = [&](Value val) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
};
if (vecAttr.isSplat()) {
@@ -750,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
- setLayoutIfNeeded(cstOp->getResult(0));
+ setLayout(cstOp->getResult(0));
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
// subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
- setLayoutIfNeeded(newConstOp->getResult(0));
+ setLayout(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
@@ -830,8 +843,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -859,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
- setLayoutIfNeeded(baseConstVec);
- setLayoutIfNeeded(bcastOffset);
- setLayoutIfNeeded(finalConst);
+ setLayout(baseConstVec);
+ setLayout(bcastOffset);
+ setLayout(finalConst);
newConstOps.push_back(finalConst);
}
rewriter.replaceOpWithMultiple(op, {newConstOps});
@@ -912,11 +925,12 @@ struct WgToSgLoadGatherOpWithOffset
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLayout = layout.dropSgLayoutAndData();
auto newLoadOp = xegpu::LoadGatherOp::create(
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
- op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
- xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
- layout.dropSgLayoutAndData());
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+ newLayout);
+ xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -964,16 +978,14 @@ struct WgToSgStoreScatterOpWithOffset
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
auto store = xegpu::StoreScatterOp::create(
rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
- op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+ layout.dropSgLayoutAndData());
// Update the layout attribute to drop sg_layout and sg_data.
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- for (OpOperand &operand : store->getOpOperands()) {
- // Skip for operand one (memref)
- if (operand.getOperandNumber() == 1)
- continue;
- xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
- }
+ for (OpOperand &operand : store->getOpOperands()) {
+ // Skip for operand one (memref)
+ if (operand.getOperandNumber() == 1)
+ continue;
+ xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
}
}
rewriter.eraseOp(op);
@@ -1052,7 +1064,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -1065,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
auto finalSteps =
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- xegpu::setDistributeLayoutAttr(steps->getResult(0),
- layout.dropSgLayoutAndData());
- xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
- layout.dropSgLayoutAndData());
- xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
- layout.dropSgLayoutAndData());
- }
+ xegpu::setDistributeLayoutAttr(steps->getResult(0),
+ layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+ layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+ layout.dropSgLayoutAndData());
newOps.push_back(finalSteps);
}
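For reference, the step distribution above boils down to adding each subgroup's starting offset to a local iota. A minimal scalar sketch, with assumed example shapes (a 1-D workgroup vector of length 16 split into subgroup tiles of length 4; values are illustrative, not taken from the patch):

// Sketch only: models finalSteps = vector.step(sgLen) + broadcast(sgOffset).
#include <cstdint>
#include <vector>
static std::vector<int64_t> sgSteps(int64_t sgLen, int64_t sgOffset) {
  std::vector<int64_t> out(sgLen);
  for (int64_t i = 0; i < sgLen; ++i)
    out[i] = i + sgOffset; // local step plus the broadcast subgroup offset
  return out;
}
// sgSteps(4, 8) yields {8, 9, 10, 11}, i.e. exactly the slice of the
// workgroup-level vector.step owned by the subgroup at offset 8.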
@@ -1141,10 +1151,8 @@ struct WgToSgVectorShapeCastOp
for (auto src : adaptor.getSource()) {
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
newResultType, src);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+ layout.dropSgLayoutAndData());
newShapeCastOps.push_back(newShapeCast.getResult());
}
@@ -1205,10 +1213,8 @@ struct WgToSgMultiDimReductionOp
auto newOp = vector::MultiDimReductionOp::create(
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
adaptor.getAcc()[0], op.getReductionDims());
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newOp->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+ layout.dropSgLayoutAndData());
newReductions.push_back(newOp.getResult());
}
@@ -1217,6 +1223,142 @@ struct WgToSgMultiDimReductionOp
}
};
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+ : public OpConversionPattern<vector::TransposeOp> {
+ using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getVector());
+ if (!sourceLayout || !sourceLayout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sourceSgLayout =
+ sourceLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+ DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+ DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+ if (!sourceOrder || !resultOrder) {
+ return rewriter.notifyMatchFailure(
+ op, "Both source and result must have order attributes");
+ }
+
+ ArrayRef<int64_t> permutation = op.getPermutation();
+ size_t permutationSize = permutation.size();
+ if (sourceSgLayout.size() != permutationSize ||
+ resultSgLayout.size() != permutationSize) {
+ return rewriter.notifyMatchFailure(
+ op, "Layouts and permutation must have the same rank");
+ }
+
+ // Check that sgLayout, sgData & order are properly transposed for source
+ // and result
+ if (!layout.isTransposeOf(sourceLayout, permutation))
+ return rewriter.notifyMatchFailure(
+ op, "Result layout is not a valid transpose of source layout "
+ "according to permutation");
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+ SmallVector<Value> newTransposeOps;
+ for (auto src : adaptor.getVector()) {
+ auto newTranspose = vector::TransposeOp::create(
+ rewriter, op.getLoc(), newResultType, src, permutation);
+ xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+ layout.dropSgLayoutAndData());
+ newTransposeOps.push_back(newTranspose.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+ return success();
+ }
+};
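To make the isTransposeOf requirement concrete, here is a small sketch of the check it is assumed to perform on the sg_layout dimensions (assumed semantics; applyPermutation below is illustrative, not the dialect's own helper):

// Sketch only: the result sg_layout must equal the source sg_layout permuted
// the same way the vector dimensions are permuted, e.g. [4, 8] with
// permutation [1, 0] must become [8, 4].
#include <cstddef>
#include <cstdint>
#include <vector>
static std::vector<int64_t> applyPermutation(const std::vector<int64_t> &in,
                                             const std::vector<int64_t> &perm) {
  std::vector<int64_t> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    out[i] = in[perm[i]]; // result dim i is sourced from dim perm[i]
  return out;
}
// applyPermutation({4, 8}, {1, 0}) == {8, 4}; if the result layout's sg_layout
// (or sg_data/order) does not line up this way, the pattern reports a match
// failure instead of rewriting.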
+
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ MaskOpType op,
+ typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ SmallVector<Value> wgMaskDimSizes;
+ if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+ for (int64_t maskSize : op.getMaskDimSizes()) {
+ wgMaskDimSizes.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, maskSize));
+ }
+ } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+ wgMaskDimSizes = llvm::to_vector(op.getOperands());
+ }
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+ // In each dimension, each subgroup computes its local mask size as:
+ // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
+ SmallVector<Value> newCreateMaskOps;
+ for (auto offsetSet : *sgOffsets) {
+ SmallVector<Value> maskOperands;
+
+ for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
+ Value dimSizeVal =
+ arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ Value offset = offsetSet[i];
+ Value adjustedMaskSize =
+ arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value nonNegative =
+ arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+ Value sgMaskSize =
+ arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
+ maskOperands.push_back(sgMaskSize);
+ }
+
+ auto newCreateMaskOp =
+ vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+ xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+ return success();
+ }
+};
+
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
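A worked instance of the per-dimension clamp used by the mask patterns above, with assumed example values (a workgroup mask size of 20 in one dimension and subgroup tiles of 8; not taken from the patch):

// Sketch only: min(max(wgMaskDimSize - offset, 0), sgDimSize).
#include <algorithm>
#include <cstdint>
static int64_t sgMaskSize(int64_t wgMaskDimSize, int64_t offset,
                          int64_t sgDimSize) {
  return std::min(std::max(wgMaskDimSize - offset, int64_t(0)), sgDimSize);
}
// Subgroups at offsets 0, 8 and 16 get local mask sizes 8, 8 and 4, so the 20
// active lanes of the workgroup-level mask are covered exactly once.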
} // namespace
namespace mlir {
@@ -1231,7 +1373,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
- WgToSgMultiDimReductionOp>(patterns.getContext());
+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+ WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1358,7 +1502,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+ vector::TransposeOp, vector::BroadcastOp,
+ vector::MultiDimReductionOp,
+ vector::ConstantMaskOp, vector::CreateMaskOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1377,16 +1524,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::BroadcastOp>(
- [=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
- target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
- [=](vector::MultiDimReductionOp op) -> bool {
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
- });
-
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index a38993e..9f126fe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -140,10 +139,14 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
// for StoreMatrixOp, the layout is attached to the property of the op
if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
return storeOp.getLayoutAttr();
-
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+    // check for "permanent" layout only after "temporary" layout name lookup
+ // for backward compatibility
+ if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+ return loadGatherOp.getLayoutAttr();
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -171,27 +174,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+  // check for "permanent" layout only after "temporary" layout name lookup
+ if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+ if (auto layout = storeScatterOp.getLayoutAttr())
+ return layout;
+
return getDistributeLayoutAttr(opr.get());
}
+// Returns the permanent layout attribute for the given result if it's
+// available on the defining op. Otherwise returns the provided layout.
+xegpu::DistributeLayoutAttr
+maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
+ const OpResult &result, mlir::Operation *owner,
+ const std::string &name) {
+ xegpu::DistributeLayoutAttr candidate = layout;
+
+ if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
+ if (auto perm = loadOp.getLayoutAttr())
+ candidate = perm;
+ }
+
+ return candidate;
+}
+
+// Returns the permanent layout attribute for the given operand if it's
+// available on the owning op. Otherwise returns the provided layout.
+xegpu::DistributeLayoutAttr
+maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
+ const OpOperand &operand, mlir::Operation *owner,
+ const std::string &name) {
+ xegpu::DistributeLayoutAttr candidate = layout;
+ unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+
+ if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
+ if (idx == 0) {
+ if (auto perm = storeOp.getLayoutAttr())
+ candidate = perm;
+ }
+ }
+
+ return candidate;
+}
+
template <typename T, typename>
void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
- const DistributeLayoutAttr layout) {
+ const DistributeLayoutAttr layout,
+ bool respectPermLayout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
- owner->setAttr(name, layout);
+
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
+ return;
+
+ DistributeLayoutAttr candidate = layout;
+ if (respectPermLayout)
+ candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name);
+
+ if (candidate)
+ owner->setAttr(name, candidate);
}
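A call sketch for the new respectPermLayout flag (loadOp and fallbackLayout are hypothetical names, not from the patch): when the flag is set, a layout carried permanently on the op, e.g. a LoadGatherOp's layout attribute, takes precedence over the layout passed by the caller.

// Sketch only, assuming loadOp is an xegpu::LoadGatherOp and fallbackLayout is
// whatever the propagation pass computed for its result.
xegpu::setDistributeLayoutAttr(loadOp->getResult(0), fallbackLayout,
                               /*respectPermLayout=*/true);
// Passing false keeps the previous behavior: the given layout is attached
// unless a layout attribute with that name is already present on the op.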
// Explicit instantiation for OpResult
template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout);
+ const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
// Explicit instantiation for OpOperand
template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
const mlir::OpOperand &operand,
- const mlir::xegpu::DistributeLayoutAttr layout);
+ const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
void xegpu::setDistributeLayoutAttrs(
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
@@ -253,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
int64_t rankDiff = srcShapeRank - targetShapeRank;
std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
1);
- std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+ llvm::copy(shape, adjustedTargetShape.begin() + rankDiff);
SmallVector<Value> result;
for (SmallVector<int64_t> offsets :
@@ -473,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
- results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
}
return results;
}
@@ -500,3 +553,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
results.append(addElementwise(builder, loc, a, b));
return results;
}
+
+template <typename T>
+int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
+ ArrayRef<T> candidateMultiples) {
+ static_assert(std::is_integral<T>::value, "T must be an integer type");
+ int largest = -1;
+ SmallVector<T> multiples = {1};
+ if (!candidateMultiples.empty())
+ multiples =
+ SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
+ for (T candidate : candidates) {
+ for (T multiple : multiples) {
+ int value = static_cast<int>(candidate * multiple);
+ if (value != 0 && dim % value == 0 && value > largest)
+ largest = value;
+ }
+ }
+ return largest;
+}
+
+/// Explicit instantiations
+template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
+ ArrayRef<int> candidateMultiples);
+template int
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
+ ArrayRef<unsigned> candidateMultiples);
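A usage sketch of getLargestDivisor with made-up values (not from the patch), to show what the helper returns:

// Sketch only: each candidate is scaled by each multiple, and the largest
// product that evenly divides dim is returned (-1 if none does).
int largest = xegpu::getLargestDivisor<int>(/*dim=*/12,
                                            /*candidates=*/{4, 3},
                                            /*candidateMultiples=*/{2});
// Products considered: 4*2 = 8 (does not divide 12) and 3*2 = 6 (divides 12),
// so largest == 6. With an empty candidateMultiples, a multiple of 1 is used.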