Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp90
-rw-r--r--mlir/lib/Analysis/Presburger/Barvinok.cpp14
-rw-r--r--mlir/lib/Analysis/Presburger/IntegerRelation.cpp87
-rw-r--r--mlir/lib/Analysis/Presburger/Matrix.cpp41
-rw-r--r--mlir/lib/Bindings/Python/DialectLLVM.cpp7
-rw-r--r--mlir/lib/Bindings/Python/DialectLinalg.cpp22
-rw-r--r--mlir/lib/Bindings/Python/ExecutionEngineModule.cpp10
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp1842
-rw-r--r--mlir/lib/Bindings/Python/MainModule.cpp14
-rw-r--r--mlir/lib/Bindings/Python/NanobindUtils.h13
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp260
-rw-r--r--mlir/lib/CAPI/Dialect/LLVM.cpp23
-rw-r--r--mlir/lib/CAPI/Dialect/Linalg.cpp38
-rw-r--r--mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp15
-rw-r--r--mlir/lib/CAPI/IR/BuiltinTypes.cpp2
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp5
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp758
-rw-r--r--mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp665
-rw-r--r--mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt19
-rw-r--r--mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp6
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp3
-rw-r--r--mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp9
-rw-r--r--mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp5
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp5
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp52
-rw-r--r--mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp38
-rw-r--r--mlir/lib/Conversion/LLVMCommon/Pattern.cpp21
-rw-r--r--mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp42
-rw-r--r--mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp13
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp14
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp20
-rw-r--r--mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp3
-rw-r--r--mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp7
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp18
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp3
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp26
-rw-r--r--mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp35
-rw-r--r--mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp14
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp17
-rw-r--r--mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp19
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp428
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp27
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp210
-rw-r--r--mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp34
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp2
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp2
-rw-r--r--mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp6
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td4
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp2
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp6
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp12
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp3
-rw-r--r--mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp8
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp10
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp2
-rw-r--r--mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp34
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp29
-rw-r--r--mlir/lib/Dialect/Func/Utils/Utils.cpp25
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp4
-rw-r--r--mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp1
-rw-r--r--mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp12
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp4
-rw-r--r--mlir/lib/Dialect/IRDL/IRDLLoading.cpp4
-rw-r--r--mlir/lib/Dialect/Index/IR/IndexDialect.cpp14
-rw-r--r--mlir/lib/Dialect/LLVMIR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp11
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp19
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp28
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp4
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp1833
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp45
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp23
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp301
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp68
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp28
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp31
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp34
-rw-r--r--mlir/lib/Dialect/Linalg/Utils/Utils.cpp725
-rw-r--r--mlir/lib/Dialect/MemRef/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp1
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp19
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp7
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp6
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp32
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp33
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp89
-rw-r--r--mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp7
-rw-r--r--mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp6
-rw-r--r--mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp7
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp560
-rw-r--r--mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp146
-rw-r--r--mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp431
-rw-r--r--mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp237
-rw-r--r--mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp117
-rw-r--r--mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt3
-rw-r--r--mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp111
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp32
-rw-r--r--mlir/lib/Dialect/SCF/IR/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp139
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp216
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp83
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp77
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp5
-rw-r--r--mlir/lib/Dialect/Shard/IR/ShardOps.cpp82
-rw-r--r--mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp8
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp4
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp72
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h2
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp76
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp56
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp7
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp85
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp20
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp20
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt4
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp111
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp9
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp18
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp310
-rw-r--r--mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp18
-rw-r--r--mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp13
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp4
-rw-r--r--mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp3
-rw-r--r--mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp8
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp16
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp50
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp14
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp285
-rw-r--r--mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp43
-rw-r--r--mlir/lib/Dialect/X86Vector/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt17
-rw-r--r--mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp64
-rw-r--r--mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt3
-rw-r--r--mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp143
-rw-r--r--mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp301
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp152
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp37
-rw-r--r--mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp514
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp2
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp588
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp419
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp19
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp182
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp6
-rw-r--r--mlir/lib/ExecutionEngine/APFloatWrappers.cpp174
-rw-r--r--mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp4
-rw-r--r--mlir/lib/ExecutionEngine/CMakeLists.txt32
-rw-r--r--mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp4
-rw-r--r--mlir/lib/ExecutionEngine/ExecutionEngine.cpp6
-rw-r--r--mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp4
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp2
-rw-r--r--mlir/lib/IR/Remarks.cpp5
-rw-r--r--mlir/lib/IR/TypeUtilities.cpp3
-rw-r--r--mlir/lib/Interfaces/InferTypeOpInterface.cpp16
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp2
-rw-r--r--mlir/lib/Pass/Pass.cpp6
-rw-r--r--mlir/lib/Query/Matcher/Parser.cpp8
-rw-r--r--mlir/lib/Reducer/ReductionTreePass.cpp18
-rw-r--r--mlir/lib/RegisterAllExtensions.cpp2
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp60
-rw-r--r--mlir/lib/Rewrite/ByteCode.h7
-rw-r--r--mlir/lib/TableGen/Interfaces.cpp8
-rw-r--r--mlir/lib/TableGen/Pattern.cpp21
-rw-r--r--mlir/lib/Target/Cpp/TranslateToCpp.cpp233
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp14
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp19
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp42
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp39
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp35
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp794
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp296
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp60
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp2
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp67
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.h17
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp21
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp17
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.h2
-rw-r--r--mlir/lib/Tools/PDLL/AST/Context.cpp2
-rw-r--r--mlir/lib/Tools/PDLL/AST/Nodes.cpp2
-rw-r--r--mlir/lib/Tools/PDLL/AST/TypeDetail.h141
-rw-r--r--mlir/lib/Tools/PDLL/AST/Types.cpp1
-rw-r--r--mlir/lib/Tools/mlir-opt/MlirOptMain.cpp33
-rw-r--r--mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp9
-rw-r--r--mlir/lib/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp141
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp107
-rw-r--r--mlir/lib/Transforms/Utils/Inliner.cpp4
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp17
202 files changed, 13767 insertions, 3259 deletions
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 70b56ca..a93e605 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return;
}
- /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
- /// on a LoopLikeInterface return the lower/upper bound for that result if
- /// possible.
- auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
- Type boundType, Block *block, bool getUpper) {
+ /// Given a lower bound, upper bound, or step from a LoopLikeInterface return
+ /// the lower/upper bound for that result if possible.
+ auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
+ Block *block, bool getUpper) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
- if (loopBound.has_value()) {
- if (auto attr = dyn_cast<Attribute>(*loopBound)) {
- if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
- return bound.getValue();
- } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
- const IntegerValueRangeLattice *lattice =
- getLatticeElementFor(getProgramPointBefore(block), value);
- if (lattice != nullptr && !lattice->getValue().isUninitialized())
- return getUpper ? lattice->getValue().getValue().smax()
- : lattice->getValue().getValue().smin();
- }
+ if (auto attr = dyn_cast<Attribute>(loopBound)) {
+ if (auto bound = dyn_cast<IntegerAttr>(attr))
+ return bound.getValue();
+ } else if (auto value = llvm::dyn_cast<Value>(loopBound)) {
+ const IntegerValueRangeLattice *lattice =
+ getLatticeElementFor(getProgramPointBefore(block), value);
+ if (lattice != nullptr && !lattice->getValue().isUninitialized())
+ return getUpper ? lattice->getValue().getValue().smax()
+ : lattice->getValue().getValue().smin();
}
// Given the results of getConstant{Lower,Upper}Bound()
// or getConstantStep() on a LoopLikeInterface return the lower/upper
@@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
// Infer bounds for loop arguments that have static bounds
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
- std::optional<Value> iv = loop.getSingleInductionVar();
- if (!iv) {
+ std::optional<llvm::SmallVector<Value>> maybeIvs =
+ loop.getLoopInductionVars();
+ if (!maybeIvs) {
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
op, successor, argLattices, firstIndex);
}
- Block *block = iv->getParentBlock();
- std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
- std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
- std::optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
- /*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
- /*getUpper=*/true);
- // Assume positivity for uniscoverable steps by way of getUpper = true.
- APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
-
- if (stepVal.isNegative()) {
- std::swap(min, max);
- } else {
- // Correct the upper bound by subtracting 1 so that it becomes a <=
- // bound, because loops do not generally include their upper bound.
- max -= 1;
- }
+ // This shouldn't be returning nullopt if there are induction variables.
+ SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds();
+ SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds();
+ SmallVector<OpFoldResult> steps = *loop.getLoopSteps();
+ for (auto [iv, lowerBound, upperBound, step] :
+ llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) {
+ Block *block = iv.getParentBlock();
+ APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block,
+ /*getUpper=*/false);
+ APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block,
+ /*getUpper=*/true);
+ // Assume positivity for uniscoverable steps by way of getUpper = true.
+ APInt stepVal =
+ getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true);
+
+ if (stepVal.isNegative()) {
+ std::swap(min, max);
+ } else {
+ // Correct the upper bound by subtracting 1 so that it becomes a <=
+ // bound, because loops do not generally include their upper bound.
+ max -= 1;
+ }
- // If we infer the lower bound to be larger than the upper bound, the
- // resulting range is meaningless and should not be used in further
- // inferences.
- if (max.sge(min)) {
- IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
- auto ivRange = ConstantIntRanges::fromSigned(min, max);
- propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
+ // If we infer the lower bound to be larger than the upper bound, the
+ // resulting range is meaningless and should not be used in further
+ // inferences.
+ if (max.sge(min)) {
+ IntegerValueRangeLattice *ivEntry = getLatticeElement(iv);
+ auto ivRange = ConstantIntRanges::fromSigned(min, max);
+ propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
+ }
}
return;
}
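For reference, a minimal standalone sketch of the per-induction-variable range computation introduced above (plain ints stand in for APInt, and the helper name is hypothetical, not part of the patch):

#include <algorithm>
#include <cassert>
#include <utility>

// Closed signed range [min, max] implied by (lb, ub, step), mirroring the
// swap/decrement logic in visitNonControlFlowArguments.
std::pair<int, int> inductionVarRange(int lb, int ub, int step) {
  int min = lb, max = ub;
  if (step < 0) {
    // Negative step: the loop counts down, so the bounds swap roles.
    std::swap(min, max);
  } else {
    // Loops generally exclude their upper bound, so turn it into a <= bound.
    max -= 1;
  }
  return {min, max}; // Only meaningful when max >= min.
}

int main() {
  auto [lo, hi] = inductionVarRange(/*lb=*/0, /*ub=*/10, /*step=*/2);
  assert(lo == 0 && hi == 9);
  return 0;
}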
diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp
index 75d592e..c31b277 100644
--- a/mlir/lib/Analysis/Presburger/Barvinok.cpp
+++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp
@@ -178,13 +178,13 @@ mlir::presburger::detail::solveParametricEquations(FracMatrix equations) {
for (unsigned i = 0; i < d; ++i) {
// First ensure that the diagonal element is nonzero, by swapping
// it with a row that is non-zero at column i.
- if (equations(i, i) != 0)
- continue;
- for (unsigned j = i + 1; j < d; ++j) {
- if (equations(j, i) == 0)
- continue;
- equations.swapRows(j, i);
- break;
+ if (equations(i, i) == 0) {
+ for (unsigned j = i + 1; j < d; ++j) {
+ if (equations(j, i) == 0)
+ continue;
+ equations.swapRows(j, i);
+ break;
+ }
}
Fraction diagElement = equations(i, i);
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 812043d..26197ce 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -21,6 +21,7 @@
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -442,6 +443,14 @@ void IntegerRelation::removeInequality(unsigned pos) {
inequalities.removeRow(pos);
}
+void IntegerRelation::removeConstraint(unsigned pos) {
+ if (pos >= getNumInequalities()) {
+ removeEquality(pos - getNumInequalities());
+ } else {
+ removeInequality(pos);
+ }
+}
+
void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
if (start >= end)
return;
@@ -1112,15 +1121,29 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
return posLimit - posStart;
}
+static std::optional<unsigned>
+findEqualityWithNonZeroAfterRow(IntegerRelation &rel, unsigned fromRow,
+ unsigned colIdx) {
+ assert(fromRow < rel.getNumEqualities() && colIdx < rel.getNumCols() &&
+ "position out of bounds");
+ for (unsigned rowIdx = fromRow, e = rel.getNumEqualities(); rowIdx < e;
+ ++rowIdx) {
+ if (rel.atEq(rowIdx, colIdx) != 0)
+ return rowIdx;
+ }
+ return std::nullopt;
+}
+
bool IntegerRelation::gaussianEliminate() {
gcdTightenInequalities();
unsigned firstVar = 0, vars = getNumVars();
unsigned nowDone, eqs;
std::optional<unsigned> pivotRow;
for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) {
- // Finds the first non-empty column.
+ // Finds the first non-empty column that we haven't dealt with.
for (; firstVar < vars; ++firstVar) {
- if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true)))
+ if ((pivotRow =
+ findEqualityWithNonZeroAfterRow(*this, nowDone, firstVar)))
break;
}
// The matrix has been normalized to row echelon form.
@@ -1143,6 +1166,10 @@ bool IntegerRelation::gaussianEliminate() {
inequalities.normalizeRow(i);
}
gcdTightenInequalities();
+
+ // The column is finished. Tell the next iteration to start at the next
+ // column.
+ firstVar++;
}
// No redundant rows.
@@ -1724,12 +1751,64 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
return minDiff;
}
+void IntegerRelation::pruneOrthogonalConstraints(unsigned pos) {
+ llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
+
+ // Early exit if there are no constraints.
+ unsigned numConstraints = getNumConstraints();
+ if (numConstraints == 0)
+ return;
+
+ llvm::SmallVector<unsigned> rowStack, colStack({pos});
+ // The following code performs a graph traversal, starting from the target
+ // variable, to identify all variables(recorded in relatedCols) and
+ // constraints (recorded in relatedRows) belonging to the same connected
+ // component.
+ while (!rowStack.empty() || !colStack.empty()) {
+ if (!rowStack.empty()) {
+ unsigned currentRow = rowStack.pop_back_val();
+ // Push all variables associated with this constraint to relatedCols
+ // and colStack.
+ for (unsigned colIndex = 0; colIndex < getNumVars(); ++colIndex) {
+ if (atConstraint(currentRow, colIndex) != 0 &&
+ relatedCols.insert(colIndex).second) {
+ colStack.push_back(colIndex);
+ }
+ }
+ } else {
+ unsigned currentCol = colStack.pop_back_val();
+ // Push all constraints that are associated with this variable to related
+ // rows and the row stack.
+ for (unsigned rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
+ if (atConstraint(rowIndex, currentCol) != 0 &&
+ relatedRows.insert(rowIndex).second) {
+ rowStack.push_back(rowIndex);
+ }
+ }
+ }
+ }
+
+ // Prune all constraints not related to the target variable.
+ for (int constraintId = numConstraints - 1; constraintId >= 0;
+ --constraintId) {
+ if (!relatedRows.contains(constraintId))
+ removeConstraint((unsigned)constraintId);
+ }
+}
+
template <bool isLower>
std::optional<DynamicAPInt>
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
assert(pos < getNumVars() && "invalid position");
// Project to 'pos'.
+ // Prune orthogonal constraints to reduce unnecessary computations and
+ // accelerate the bound computation.
+ pruneOrthogonalConstraints(pos);
projectOut(0, pos);
+
+ // After projecting out values, more orthogonal constraints may be exposed.
+ // Prune these orthogonal constraints again.
+ pruneOrthogonalConstraints(0);
projectOut(1, getNumVars() - 1);
// Check if there's an equality equating the '0'^th variable to a constant.
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);
@@ -2265,11 +2344,11 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
newLb[d] = lbFloorDivisor;
newUb[d] = -lbFloorDivisor;
// Copy over the symbolic part + constant term.
- std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars());
+ llvm::copy(minLb, newLb.begin() + getNumDimVars());
std::transform(newLb.begin() + getNumDimVars(), newLb.end(),
newLb.begin() + getNumDimVars(),
std::negate<DynamicAPInt>());
- std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars());
+ llvm::copy(maxUb, newUb.begin() + getNumDimVars());
boundingLbs.emplace_back(newLb);
boundingUbs.emplace_back(newUb);
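For reference, the connected-component idea behind pruneOrthogonalConstraints above can be sketched in isolation as follows (a minimal sketch over a plain dense matrix; the types and the function name are hypothetical, not part of the patch):

#include <cstddef>
#include <set>
#include <vector>

using DenseConstraints = std::vector<std::vector<int>>;

// Constraints (rows) and variables (columns) form a bipartite graph; a row
// and a column are connected when the coefficient at their intersection is
// nonzero. Returns the rows transitively reachable from column `pos`; all
// other rows are orthogonal to it and can be dropped.
std::set<std::size_t> relatedConstraints(const DenseConstraints &constraints,
                                         std::size_t pos) {
  std::set<std::size_t> relatedRows, relatedCols{pos};
  std::vector<std::size_t> rowStack, colStack{pos};
  while (!rowStack.empty() || !colStack.empty()) {
    if (!rowStack.empty()) {
      std::size_t row = rowStack.back();
      rowStack.pop_back();
      // Visit every variable appearing in this constraint.
      for (std::size_t col = 0; col < constraints[row].size(); ++col)
        if (constraints[row][col] != 0 && relatedCols.insert(col).second)
          colStack.push_back(col);
    } else {
      std::size_t col = colStack.back();
      colStack.pop_back();
      // Visit every constraint mentioning this variable.
      for (std::size_t row = 0; row < constraints.size(); ++row)
        if (constraints[row][col] != 0 && relatedRows.insert(row).second)
          rowStack.push_back(row);
    }
  }
  return relatedRows;
}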
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index bb60564..83a2c28 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -255,20 +255,13 @@ void Matrix<T>::fillRow(unsigned row, const T &value) {
}
// moveColumns is implemented by moving the columns adjacent to the source range
-// to their final position. When moving right (i.e. dstPos > srcPos), the range
-// of the adjacent columns is [srcPos + num, dstPos + num). When moving left
-// (i.e. dstPos < srcPos) the range of the adjacent columns is [dstPos, srcPos).
-// First, zeroed out columns are inserted in the final positions of the adjacent
-// columns. Then, the adjacent columns are moved to their final positions by
-// swapping them with the zeroed columns. Finally, the now zeroed adjacent
-// columns are deleted.
+// to their final position.
template <typename T>
void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) {
if (num == 0)
return;
- int offset = dstPos - srcPos;
- if (offset == 0)
+ if (dstPos == srcPos)
return;
assert(srcPos + num <= getNumColumns() &&
@@ -276,23 +269,19 @@ void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) {
assert(dstPos + num <= getNumColumns() &&
"move destination range exceeds matrix columns");
- unsigned insertCount = offset > 0 ? offset : -offset;
- unsigned finalAdjStart = offset > 0 ? srcPos : srcPos + num;
- unsigned curAdjStart = offset > 0 ? srcPos + num : dstPos;
- // TODO: This can be done using std::rotate.
- // Insert new zero columns in the positions where the adjacent columns are to
- // be moved.
- insertColumns(finalAdjStart, insertCount);
- // Update curAdjStart if insertion of new columns invalidates it.
- if (finalAdjStart < curAdjStart)
- curAdjStart += insertCount;
-
- // Swap the adjacent columns with inserted zero columns.
- for (unsigned i = 0; i < insertCount; ++i)
- swapColumns(finalAdjStart + i, curAdjStart + i);
-
- // Delete the now redundant zero columns.
- removeColumns(curAdjStart, insertCount);
+ unsigned numRows = getNumRows();
+ // std::rotate(start, middle, end) permutes the elements of [start, end) to
+ // [middle, end) + [start, middle). NOTE: &at(i, srcPos + num) will trigger an
+ // assert.
+ if (dstPos > srcPos) {
+ for (unsigned i = 0; i < numRows; ++i) {
+ std::rotate(&at(i, srcPos), &at(i, srcPos) + num, &at(i, dstPos) + num);
+ }
+ return;
+ }
+ for (unsigned i = 0; i < numRows; ++i) {
+ std::rotate(&at(i, dstPos), &at(i, srcPos), &at(i, srcPos) + num);
+ }
}
template <typename T>
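For reference, the std::rotate trick used by the new moveColumns above, shown on a single contiguous row (a minimal standalone sketch; the values and positions are illustrative only):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // One row with six columns; move the two columns starting at srcPos = 1 so
  // they end up starting at dstPos = 3 (a move to the right).
  std::vector<int> row = {0, 1, 2, 3, 4, 5};
  unsigned srcPos = 1, num = 2, dstPos = 3;
  // Rotate [srcPos, dstPos + num): the block {1, 2} lands at positions 3-4
  // and the displaced columns {3, 4} shift left to fill the gap.
  std::rotate(row.begin() + srcPos, row.begin() + srcPos + num,
              row.begin() + dstPos + num);
  assert((row == std::vector<int>{0, 3, 4, 1, 2, 5}));
  return 0;
}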
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 870a713..05681ce 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
// StructType
//===--------------------------------------------------------------------===//
- auto llvmStructType =
- mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
+ auto llvmStructType = mlir_type_subclass(
+ m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
llvmStructType
.def_classmethod(
@@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
// PointerType
//===--------------------------------------------------------------------===//
- mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
+ mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType,
+ mlirLLVMPointerTypeGetTypeID)
.def_classmethod(
"get",
[](const nb::object &cls, std::optional<unsigned> addressSpace,
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 0155023..0b079b4 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"op.",
nb::arg("op"));
+ m.def(
+ "infer_contraction_dimensions_from_maps",
+ [](std::vector<MlirAffineMap> indexingMaps)
+ -> std::optional<MlirLinalgContractionDimensions> {
+ if (indexingMaps.empty())
+ return std::nullopt;
+
+ MlirLinalgContractionDimensions dims =
+ mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
+ indexingMaps.size());
+
+ // Detect "empty" result from invalid input or failed inference.
+ if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+ mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+ return std::nullopt;
+ }
+ return dims;
+ },
+ "Infers contraction dimensions (batch/m/n/k) from a list of affine "
+ "maps.",
+ nb::arg("indexing_maps"));
+
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
"Checks if the given operation is a Linalg convolution operation.",
nb::arg("op"));
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index 8bb493e..be0785b1 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -75,13 +75,13 @@ NB_MODULE(_mlirExecutionEngine, m) {
"__init__",
[](PyExecutionEngine &self, MlirModule module, int optLevel,
const std::vector<std::string> &sharedLibPaths,
- bool enableObjectDump) {
+ bool enableObjectDump, bool enablePIC) {
llvm::SmallVector<MlirStringRef, 4> libPaths;
for (const std::string &path : sharedLibPaths)
libPaths.push_back({path.c_str(), path.length()});
- MlirExecutionEngine executionEngine =
- mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
- libPaths.data(), enableObjectDump);
+ MlirExecutionEngine executionEngine = mlirExecutionEngineCreate(
+ module, optLevel, libPaths.size(), libPaths.data(),
+ enableObjectDump, enablePIC);
if (mlirExecutionEngineIsNull(executionEngine))
throw std::runtime_error(
"Failure while creating the ExecutionEngine.");
@@ -89,7 +89,7 @@ NB_MODULE(_mlirExecutionEngine, m) {
},
nb::arg("module"), nb::arg("opt_level") = 2,
nb::arg("shared_libs") = nb::list(),
- nb::arg("enable_object_dump") = true,
+ nb::arg("enable_object_dump") = true, nb::arg("enable_pic") = false,
"Create a new ExecutionEngine instance for the given Module. The "
"module must contain only dialects that can be translated to LLVM. "
"Perform transformations and code generation at the optimization "
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cda4fe1..2e0c2b8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -18,6 +18,7 @@
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
+#include "nanobind/typing.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -32,33 +33,6 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-//------------------------------------------------------------------------------
-// Docstrings (trivial, non-duplicated docstrings are included inline).
-//------------------------------------------------------------------------------
-
-static const char kContextParseTypeDocstring[] =
- R"(Parses the assembly form of a type.
-
-Returns a Type object or raises an MLIRError if the type cannot be parsed.
-
-See also: https://mlir.llvm.org/docs/LangRef/#type-system
-)";
-
-static const char kContextGetCallSiteLocationDocstring[] =
- R"(Gets a Location representing a caller and callsite)";
-
-static const char kContextGetFileLocationDocstring[] =
- R"(Gets a Location representing a file, line and column)";
-
-static const char kContextGetFileRangeDocstring[] =
- R"(Gets a Location representing a file, line and column range)";
-
-static const char kContextGetFusedLocationDocstring[] =
- R"(Gets a Location representing a fused location with optional metadata)";
-
-static const char kContextGetNameLocationDocString[] =
- R"(Gets a Location representing a named location with optional child location)";
-
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.
@@ -67,132 +41,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";
-static const char kModuleCAPICreate[] =
- R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
-Note this returns a new object BUT _clear_mlir_module(module) must be called to
-prevent double-frees (of the underlying mlir::Module).
-)";
-
-static const char kOperationCreateDocstring[] =
- R"(Creates a new operation.
-
-Args:
- name: Operation name (e.g. "dialect.operation").
- results: Sequence of Type representing op result types.
- attributes: Dict of str:Attribute.
- successors: List of Block for the operation's successors.
- regions: Number of regions to create.
- location: A Location object (defaults to resolve from context manager).
- ip: An InsertionPoint (defaults to resolve from context manager or set to
- False to disable insertion, even with an insertion point set in the
- context manager).
- infer_type: Whether to infer result types.
-Returns:
- A new "detached" Operation object. Detached operations can be added
- to blocks, which causes them to become "attached."
-)";
-
-static const char kOperationPrintDocstring[] =
- R"(Prints the assembly form of the operation to a file like object.
-
-Args:
- file: The file like object to write to. Defaults to sys.stdout.
- binary: Whether to write bytes (True) or str (False). Defaults to False.
- large_elements_limit: Whether to elide elements attributes above this
- number of elements. Defaults to None (no limit).
- large_resource_limit: Whether to elide resource attributes above this
- number of characters. Defaults to None (no limit). If large_elements_limit
- is set and this is None, the behavior will be to use large_elements_limit
- as large_resource_limit.
- enable_debug_info: Whether to print debug/location information. Defaults
- to False.
- pretty_debug_info: Whether to format debug information for easier reading
- by a human (warning: the result is unparseable).
- print_generic_op_form: Whether to print the generic assembly forms of all
- ops. Defaults to False.
- use_local_Scope: Whether to print in a way that is more optimized for
- multi-threaded access but may not be consistent with how the overall
- module prints.
- assume_verified: By default, if not printing generic form, the verifier
- will be run and if it fails, generic form will be printed with a comment
- about failed verification. While a reasonable default for interactive use,
- for systematic use, it is often better for the caller to verify explicitly
- and report failures in a more robust fashion. Set this to True if doing this
- in order to avoid running a redundant verification. If the IR is actually
- invalid, behavior is undefined.
- skip_regions: Whether to skip printing regions. Defaults to False.
-)";
-
-static const char kOperationPrintStateDocstring[] =
- R"(Prints the assembly form of the operation to a file like object.
-
-Args:
- file: The file like object to write to. Defaults to sys.stdout.
- binary: Whether to write bytes (True) or str (False). Defaults to False.
- state: AsmState capturing the operation numbering and flags.
-)";
-
-static const char kOperationGetAsmDocstring[] =
- R"(Gets the assembly form of the operation with all options available.
-
-Args:
- binary: Whether to return a bytes (True) or str (False) object. Defaults to
- False.
- ... others ...: See the print() method for common keyword arguments for
- configuring the printout.
-Returns:
- Either a bytes or str object, depending on the setting of the 'binary'
- argument.
-)";
-
-static const char kOperationPrintBytecodeDocstring[] =
- R"(Write the bytecode form of the operation to a file like object.
-
-Args:
- file: The file like object to write to.
- desired_version: The version of bytecode to emit.
-Returns:
- The bytecode writer status.
-)";
-
-static const char kOperationStrDunderDocstring[] =
- R"(Gets the assembly form of the operation with default options.
-
-If more advanced control over the assembly formatting or I/O options is needed,
-use the dedicated print or get_asm method, which supports keyword arguments to
-customize behavior.
-)";
-
static const char kDumpDocstring[] =
- R"(Dumps a debug representation of the object to stderr.)";
-
-static const char kAppendBlockDocstring[] =
- R"(Appends a new block, with argument types as positional args.
-
-Returns:
- The created block.
-)";
-
-static const char kValueDunderStrDocstring[] =
- R"(Returns the string form of the value.
-
-If the value is a block argument, this is the assembly form of its type and the
-position in the argument list. If the value is an operation result, this is
-equivalent to printing the operation that produced it.
-)";
-
-static const char kGetNameAsOperand[] =
- R"(Returns the string form of value as an operand (i.e., the ValueID).
-)";
-
-static const char kValueReplaceAllUsesWithDocstring[] =
- R"(Replace all uses of value with the new value, updating anything in
-the IR that uses 'self' to use the other value instead.
-)";
+ "Dumps a debug representation of the object to stderr.";
static const char kValueReplaceAllUsesExceptDocstring[] =
- R"("Replace all uses of this value with the 'with' value, except for those
-in 'exceptions'. 'exceptions' can be either a single operation or a list of
+ R"(Replace all uses of this value with the `with` value, except for those
+in `exceptions`. `exceptions` can be either a single operation or a list of
operations.
)";
@@ -274,22 +128,26 @@ struct PyGlobalDebugFlag {
// Debug flags.
nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
.def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
- &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
+ &PyGlobalDebugFlag::set, "LLVM-wide debug flag.")
.def_static(
"set_types",
[](const std::string &type) {
nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
- "types"_a, "Sets specific debug types to be produced by LLVM")
- .def_static("set_types", [](const std::vector<std::string> &types) {
- std::vector<const char *> pointers;
- pointers.reserve(types.size());
- for (const std::string &str : types)
- pointers.push_back(str.c_str());
- nb::ft_lock_guard lock(mutex);
- mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
- });
+ "types"_a, "Sets specific debug types to be produced by LLVM.")
+ .def_static(
+ "set_types",
+ [](const std::vector<std::string> &types) {
+ std::vector<const char *> pointers;
+ pointers.reserve(types.size());
+ for (const std::string &str : types)
+ pointers.push_back(str.c_str());
+ nb::ft_lock_guard lock(mutex);
+ mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
+ },
+ "types"_a,
+ "Sets multiple specific debug types to be produced by LLVM.");
}
private:
@@ -316,12 +174,18 @@ struct PyAttrBuilderMap {
static void bind(nb::module_ &m) {
nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
- .def_static("contains", &PyAttrBuilderMap::dunderContains)
- .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
+ .def_static("contains", &PyAttrBuilderMap::dunderContains,
+ "attribute_kind"_a,
+ "Checks whether an attribute builder is registered for the "
+ "given attribute kind.")
+ .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed,
+ "attribute_kind"_a,
+ "Gets the registered attribute builder for the given "
+ "attribute kind.")
.def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
"attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
"Register an attribute builder for building MLIR "
- "attributes from python values.");
+ "attributes from Python values.");
}
};
@@ -341,8 +205,8 @@ namespace {
class PyRegionIterator {
public:
- PyRegionIterator(PyOperationRef operation)
- : operation(std::move(operation)) {}
+ PyRegionIterator(PyOperationRef operation, int nextIndex)
+ : operation(std::move(operation)), nextIndex(nextIndex) {}
PyRegionIterator &dunderIter() { return *this; }
@@ -357,13 +221,15 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyRegionIterator>(m, "RegionIterator")
- .def("__iter__", &PyRegionIterator::dunderIter)
- .def("__next__", &PyRegionIterator::dunderNext);
+ .def("__iter__", &PyRegionIterator::dunderIter,
+ "Returns an iterator over the regions in the operation.")
+ .def("__next__", &PyRegionIterator::dunderNext,
+ "Returns the next region in the iteration.");
}
private:
PyOperationRef operation;
- int nextIndex = 0;
+ intptr_t nextIndex = 0;
};
/// Regions of an op are fixed length and indexed numerically so are represented
@@ -382,11 +248,12 @@ public:
PyRegionIterator dunderIter() {
operation->checkValid();
- return PyRegionIterator(operation);
+ return PyRegionIterator(operation, startIndex);
}
static void bindDerived(ClassTy &c) {
- c.def("__iter__", &PyRegionList::dunderIter);
+ c.def("__iter__", &PyRegionList::dunderIter,
+ "Returns an iterator over the regions in the sequence.");
}
private:
@@ -430,8 +297,10 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyBlockIterator>(m, "BlockIterator")
- .def("__iter__", &PyBlockIterator::dunderIter)
- .def("__next__", &PyBlockIterator::dunderNext);
+ .def("__iter__", &PyBlockIterator::dunderIter,
+ "Returns an iterator over the blocks in the operation's region.")
+ .def("__next__", &PyBlockIterator::dunderNext,
+ "Returns the next block in the iteration.");
}
private:
@@ -493,10 +362,19 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyBlockList>(m, "BlockList")
- .def("__getitem__", &PyBlockList::dunderGetItem)
- .def("__iter__", &PyBlockList::dunderIter)
- .def("__len__", &PyBlockList::dunderLen)
- .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
+ .def("__getitem__", &PyBlockList::dunderGetItem,
+ "Returns the block at the specified index.")
+ .def("__iter__", &PyBlockList::dunderIter,
+ "Returns an iterator over blocks in the operation's region.")
+ .def("__len__", &PyBlockList::dunderLen,
+ "Returns the number of blocks in the operation's region.")
+ .def("append", &PyBlockList::appendBlock,
+ R"(
+ Appends a new block, with argument types as positional args.
+
+ Returns:
+ The created block.
+ )",
nb::arg("args"), nb::kw_only(),
nb::arg("arg_locs") = std::nullopt);
}
@@ -527,8 +405,10 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyOperationIterator>(m, "OperationIterator")
- .def("__iter__", &PyOperationIterator::dunderIter)
- .def("__next__", &PyOperationIterator::dunderNext);
+ .def("__iter__", &PyOperationIterator::dunderIter,
+ "Returns an iterator over the operations in an operation's block.")
+ .def("__next__", &PyOperationIterator::dunderNext,
+ "Returns the next operation in the iteration.");
}
private:
@@ -584,9 +464,12 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyOperationList>(m, "OperationList")
- .def("__getitem__", &PyOperationList::dunderGetItem)
- .def("__iter__", &PyOperationList::dunderIter)
- .def("__len__", &PyOperationList::dunderLen);
+ .def("__getitem__", &PyOperationList::dunderGetItem,
+ "Returns the operation at the specified index.")
+ .def("__iter__", &PyOperationList::dunderIter,
+ "Returns an iterator over operations in the list.")
+ .def("__len__", &PyOperationList::dunderLen,
+ "Returns the number of operations in the list.");
}
private:
@@ -609,8 +492,10 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyOpOperand>(m, "OpOperand")
- .def_prop_ro("owner", &PyOpOperand::getOwner)
- .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
+ .def_prop_ro("owner", &PyOpOperand::getOwner,
+ "Returns the operation that owns this operand.")
+ .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
+ "Returns the operand number in the owning operation.");
}
private:
@@ -634,8 +519,10 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
- .def("__iter__", &PyOpOperandIterator::dunderIter)
- .def("__next__", &PyOpOperandIterator::dunderNext);
+ .def("__iter__", &PyOpOperandIterator::dunderIter,
+ "Returns an iterator over operands.")
+ .def("__next__", &PyOpOperandIterator::dunderNext,
+ "Returns the next operand in the iteration.");
}
private:
@@ -1524,9 +1411,10 @@ nb::object PyOperation::create(std::string_view name,
}
// Construct the operation.
+ PyMlirContext::ErrorCapture errors(location.getContext());
MlirOperation operation = mlirOperationCreate(&state);
if (!operation.ptr)
- throw nb::value_error("Operation creation failed");
+ throw MLIRError("Operation creation failed", errors.take());
PyOperationRef created =
PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);
@@ -1596,7 +1484,11 @@ public:
/// Binds the Python module objects to functions of this class.
static void bind(nb::module_ &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
+ auto cls = ClassTy(
+ m, DerivedTy::pyClassName, nb::is_generic(),
+ nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
+ .str()
+ .c_str()));
cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
cls.def_static(
"isinstance",
@@ -1626,16 +1518,21 @@ public:
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
- "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+ "owner",
+ [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
return self.getParentOperation().getObject();
- });
- c.def_prop_ro("result_number", [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- });
+ },
+ "Returns the operation that produces this result.");
+ c.def_prop_ro(
+ "result_number",
+ [](PyOpResult &self) {
+ return mlirOpResultGetResultNumber(self.get());
+ },
+ "Returns the position of this result in the operation's result list.");
}
};
@@ -1671,13 +1568,18 @@ public:
operation(std::move(operation)) {}
static void bindDerived(ClassTy &c) {
- c.def_prop_ro("types", [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- });
- c.def_prop_ro("owner",
- [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
- return self.operation->createOpView();
- });
+ c.def_prop_ro(
+ "types",
+ [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all results in this result list.");
+ c.def_prop_ro(
+ "owner",
+ [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+ return self.operation->createOpView();
+ },
+ "Returns the operation that owns this result list.");
}
PyOperationRef &getOperation() { return operation; }
@@ -2427,19 +2329,31 @@ public:
using PyConcreteValue::PyConcreteValue;
static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyBlockArgument &self) {
- return PyBlock(self.getParentOperation(),
- mlirBlockArgumentGetOwner(self.get()));
- });
- c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
- return mlirBlockArgumentGetArgNumber(self.get());
- });
+ c.def_prop_ro(
+ "owner",
+ [](PyBlockArgument &self) {
+ return PyBlock(self.getParentOperation(),
+ mlirBlockArgumentGetOwner(self.get()));
+ },
+ "Returns the block that owns this argument.");
+ c.def_prop_ro(
+ "arg_number",
+ [](PyBlockArgument &self) {
+ return mlirBlockArgumentGetArgNumber(self.get());
+ },
+ "Returns the position of this argument in the block's argument list.");
c.def(
"set_type",
[](PyBlockArgument &self, PyType type) {
return mlirBlockArgumentSetType(self.get(), type);
},
- nb::arg("type"));
+ nb::arg("type"), "Sets the type of this block argument.");
+ c.def(
+ "set_location",
+ [](PyBlockArgument &self, PyLocation loc) {
+ return mlirBlockArgumentSetLocation(self.get(), loc);
+ },
+ nb::arg("loc"), "Sets the location of this block argument.");
}
};
@@ -2462,9 +2376,12 @@ public:
operation(std::move(operation)), block(block) {}
static void bindDerived(ClassTy &c) {
- c.def_prop_ro("types", [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- });
+ c.def_prop_ro(
+ "types",
+ [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ },
+ "Returns a list of types for all arguments in this argument list.");
}
private:
@@ -2516,7 +2433,9 @@ public:
}
static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+ c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"),
+ nb::arg("value"),
+ "Sets the operand at the specified index to a new value.");
}
private:
@@ -2571,7 +2490,8 @@ public:
}
static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
+ c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"),
+ nb::arg("block"), "Sets the successor block at the specified index.");
}
private:
@@ -2743,55 +2663,70 @@ public:
static void bind(nb::module_ &m) {
nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
- .def("__contains__", &PyOpAttributeMap::dunderContains)
- .def("__len__", &PyOpAttributeMap::dunderLen)
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
- .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
- .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem)
- .def("__iter__",
- [](PyOpAttributeMap &self) {
- nb::list keys;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- keys.append(nb::str(name.data, name.length));
- });
- return nb::iter(keys);
- })
- .def("keys",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute) {
- out.append(nb::str(name.data, name.length));
- });
- return out;
- })
- .def("values",
- [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef, MlirAttribute attr) {
- out.append(PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast());
- });
- return out;
- })
- .def("items", [](PyOpAttributeMap &self) {
- nb::list out;
- PyOpAttributeMap::forEachAttr(
- self.operation->get(),
- [&](MlirStringRef name, MlirAttribute attr) {
- out.append(nb::make_tuple(
- nb::str(name.data, name.length),
- PyAttribute(self.operation->getContext(), attr)
- .maybeDownCast()));
- });
- return out;
- });
+ .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"),
+ "Checks if an attribute with the given name exists in the map.")
+ .def("__len__", &PyOpAttributeMap::dunderLen,
+ "Returns the number of attributes in the map.")
+ .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed,
+ nb::arg("name"), "Gets an attribute by name.")
+ .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed,
+ nb::arg("index"), "Gets a named attribute by index.")
+ .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"),
+ nb::arg("attr"), "Sets an attribute with the given name.")
+ .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"),
+ "Deletes an attribute with the given name.")
+ .def(
+ "__iter__",
+ [](PyOpAttributeMap &self) {
+ nb::list keys;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ keys.append(nb::str(name.data, name.length));
+ });
+ return nb::iter(keys);
+ },
+ "Iterates over attribute names.")
+ .def(
+ "keys",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ out.append(nb::str(name.data, name.length));
+ });
+ return out;
+ },
+ "Returns a list of attribute names.")
+ .def(
+ "values",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef, MlirAttribute attr) {
+ out.append(PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast());
+ });
+ return out;
+ },
+ "Returns a list of attribute values.")
+ .def(
+ "items",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute attr) {
+ out.append(nb::make_tuple(
+ nb::str(name.data, name.length),
+ PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast()));
+ });
+ return out;
+ },
+ "Returns a list of `(name, attribute)` tuples.");
}
private:
@@ -2979,62 +2914,103 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
nb::class_<PyDiagnostic>(m, "Diagnostic")
- .def_prop_ro("severity", &PyDiagnostic::getSeverity)
- .def_prop_ro("location", &PyDiagnostic::getLocation)
- .def_prop_ro("message", &PyDiagnostic::getMessage)
- .def_prop_ro("notes", &PyDiagnostic::getNotes)
- .def("__str__", [](PyDiagnostic &self) -> nb::str {
- if (!self.isValid())
- return nb::str("<Invalid Diagnostic>");
- return self.getMessage();
- });
+ .def_prop_ro("severity", &PyDiagnostic::getSeverity,
+ "Returns the severity of the diagnostic.")
+ .def_prop_ro("location", &PyDiagnostic::getLocation,
+ "Returns the location associated with the diagnostic.")
+ .def_prop_ro("message", &PyDiagnostic::getMessage,
+ "Returns the message text of the diagnostic.")
+ .def_prop_ro("notes", &PyDiagnostic::getNotes,
+ "Returns a tuple of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic &self) -> nb::str {
+ if (!self.isValid())
+ return nb::str("<Invalid Diagnostic>");
+ return self.getMessage();
+ },
+ "Returns the diagnostic message as a string.");
nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
- .def("__init__",
- [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
- new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
- })
- .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
- .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
- .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
- .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
- .def("__str__",
- [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
+ .def(
+ "__init__",
+ [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
+ new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
+ },
+ "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.")
+ .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity,
+ "The severity level of the diagnostic.")
+ .def_ro("location", &PyDiagnostic::DiagnosticInfo::location,
+ "The location associated with the diagnostic.")
+ .def_ro("message", &PyDiagnostic::DiagnosticInfo::message,
+ "The message text of the diagnostic.")
+ .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes,
+ "List of attached note diagnostics.")
+ .def(
+ "__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; },
+ "Returns the diagnostic message as a string.");
nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
- .def("detach", &PyDiagnosticHandler::detach)
- .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
- .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
- .def("__enter__", &PyDiagnosticHandler::contextEnter)
+ .def("detach", &PyDiagnosticHandler::detach,
+ "Detaches the diagnostic handler from the context.")
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached,
+ "Returns True if the handler is attached to a context.")
+ .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError,
+ "Returns True if an error was encountered during diagnostic "
+ "handling.")
+ .def("__enter__", &PyDiagnosticHandler::contextEnter,
+ "Enters the diagnostic handler as a context manager.")
.def("__exit__", &PyDiagnosticHandler::contextExit,
nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none());
+ nb::arg("traceback").none(),
+ "Exits the diagnostic handler context manager.");
// Expose DefaultThreadPool to python
nb::class_<PyThreadPool>(m, "ThreadPool")
- .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
- .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
- .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
+ .def(
+ "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); },
+ "Creates a new thread pool with default concurrency.")
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency,
+ "Returns the maximum number of threads in the pool.")
+ .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr,
+ "Returns the raw pointer to the LLVM thread pool as a string.");
nb::class_<PyMlirContext>(m, "Context")
- .def("__init__",
- [](PyMlirContext &self) {
- MlirContext context = mlirContextCreateWithThreading(false);
- new (&self) PyMlirContext(context);
- })
- .def_static("_get_live_count", &PyMlirContext::getLiveCount)
- .def("_get_context_again",
- [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- })
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
+ .def(
+ "__init__",
+ [](PyMlirContext &self) {
+ MlirContext context = mlirContextCreateWithThreading(false);
+ new (&self) PyMlirContext(context);
+ },
+ R"(
+ Creates a new MLIR context.
+
+ The context is the top-level container for all MLIR objects. It owns the storage
+ for types, attributes, locations, and other core IR objects. A context can be
+ configured to allow or disallow unregistered dialects and can have dialects
+ loaded on-demand.)")
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount,
+ "Gets the number of live Context objects.")
+ .def(
+ "_get_context_again",
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ },
+ "Gets another reference to the same context.")
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount,
+ "Gets the number of live modules owned by this context.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule,
+ "Gets a capsule wrapping the MlirContext.")
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyMlirContext::createFromCapsule)
- .def("__enter__", &PyMlirContext::contextEnter)
+ &PyMlirContext::createFromCapsule,
+ "Creates a Context from a capsule wrapping MlirContext.")
+ .def("__enter__", &PyMlirContext::contextEnter,
+ "Enters the context as a context manager.")
.def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none())
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the context manager.")
.def_prop_ro_static(
"current",
[](nb::object & /*class*/)
@@ -3045,14 +3021,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return nb::cast(context);
},
nb::sig("def current(/) -> Context | None"),
- "Gets the Context bound to the current thread or raises ValueError")
+ "Gets the Context bound to the current thread or returns None if no "
+ "context is set.")
.def_prop_ro(
"dialects",
[](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Gets a container for accessing dialects by name")
+ "Gets a container for accessing dialects by name.")
.def_prop_ro(
"d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Alias for 'dialect'")
+ "Alias for `dialects`.")
.def(
"get_dialect_descriptor",
[=](PyMlirContext &self, std::string &name) {
@@ -3065,7 +3042,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyDialectDescriptor(self.getRef(), dialect);
},
nb::arg("dialect_name"),
- "Gets or loads a dialect by name, returning its descriptor object")
+ "Gets or loads a dialect by name, returning its descriptor object.")
.def_prop_rw(
"allow_unregistered_dialects",
[](PyMlirContext &self) -> bool {
@@ -3073,67 +3050,110 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
[](PyMlirContext &self, bool value) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
- })
+ },
+ "Controls whether unregistered dialects are allowed in this context.")
.def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
nb::arg("callback"),
- "Attaches a diagnostic handler that will receive callbacks")
+ "Attaches a diagnostic handler that will receive callbacks.")
.def(
"enable_multithreading",
[](PyMlirContext &self, bool enable) {
mlirContextEnableMultithreading(self.get(), enable);
},
- nb::arg("enable"))
- .def("set_thread_pool",
- [](PyMlirContext &self, PyThreadPool &pool) {
- // we should disable multi-threading first before setting
- // new thread pool otherwise the assert in
- // MLIRContext::setThreadPool will be raised.
- mlirContextEnableMultithreading(self.get(), false);
- mlirContextSetThreadPool(self.get(), pool.get());
- })
- .def("get_num_threads",
- [](PyMlirContext &self) {
- return mlirContextGetNumThreads(self.get());
- })
- .def("_mlir_thread_pool_ptr",
- [](PyMlirContext &self) {
- MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
- std::stringstream ss;
- ss << pool.ptr;
- return ss.str();
- })
+ nb::arg("enable"),
+ R"(
+ Enables or disables multi-threading support in the context.
+
+ Args:
+ enable: Whether to enable (True) or disable (False) multi-threading.
+ )")
+ .def(
+ "set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+              // We must disable multi-threading before setting a new thread
+              // pool; otherwise the assert in MLIRContext::setThreadPool will
+              // fire.
+ mlirContextEnableMultithreading(self.get(), false);
+ mlirContextSetThreadPool(self.get(), pool.get());
+ },
+ R"(
+ Sets a custom thread pool for the context to use.
+
+ Args:
+ pool: A ThreadPool object to use for parallel operations.
+
+ Note:
+ Multi-threading is automatically disabled before setting the thread pool.)")
+ .def(
+ "get_num_threads",
+ [](PyMlirContext &self) {
+ return mlirContextGetNumThreads(self.get());
+ },
+ "Gets the number of threads in the context's thread pool.")
+ .def(
+ "_mlir_thread_pool_ptr",
+ [](PyMlirContext &self) {
+ MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
+ std::stringstream ss;
+ ss << pool.ptr;
+ return ss.str();
+ },
+ "Gets the raw pointer to the LLVM thread pool as a string.")
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
return mlirContextIsRegisteredOperation(
self.get(), MlirStringRef{name.data(), name.size()});
},
- nb::arg("operation_name"))
+ nb::arg("operation_name"),
+ R"(
+ Checks whether an operation with the given name is registered.
+
+ Args:
+ operation_name: The fully qualified name of the operation (e.g., `arith.addf`).
+
+ Returns:
+ True if the operation is registered, False otherwise.)")
.def(
"append_dialect_registry",
[](PyMlirContext &self, PyDialectRegistry &registry) {
mlirContextAppendDialectRegistry(self.get(), registry);
},
- nb::arg("registry"))
+ nb::arg("registry"),
+ R"(
+ Appends the contents of a dialect registry to the context.
+
+ Args:
+ registry: A DialectRegistry containing dialects to append.)")
.def_prop_rw("emit_error_diagnostics",
&PyMlirContext::getEmitErrorDiagnostics,
&PyMlirContext::setEmitErrorDiagnostics,
- "Emit error diagnostics to diagnostic handlers. By default "
- "error diagnostics are captured and reported through "
- "MLIRError exceptions.")
- .def("load_all_available_dialects", [](PyMlirContext &self) {
- mlirContextLoadAllAvailableDialects(self.get());
- });
+ R"(
+ Controls whether error diagnostics are emitted to diagnostic handlers.
+
+ By default, error diagnostics are captured and reported through MLIRError exceptions.)")
+ .def(
+ "load_all_available_dialects",
+ [](PyMlirContext &self) {
+ mlirContextLoadAllAvailableDialects(self.get());
+ },
+ R"(
+ Loads all dialects available in the registry into the context.
+
+ This eagerly loads all dialects that have been registered, making them
+ immediately available for use.)");
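A short usage sketch for the Context configuration methods above (illustrative
only, assuming the usual `mlir.ir` package):

    from mlir.ir import Context, ThreadPool

    pool = ThreadPool()                          # default-concurrency LLVM thread pool
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        ctx.set_thread_pool(pool)                # multi-threading is disabled first, see above
        print(ctx.get_num_threads())
        print(ctx.is_registered_operation("arith.addf"))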
//----------------------------------------------------------------------------
// Mapping of PyDialectDescriptor
//----------------------------------------------------------------------------
nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
- .def_prop_ro("namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- return nb::str(ns.data, ns.length);
- })
+ .def_prop_ro(
+ "namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ return nb::str(ns.data, ns.length);
+ },
+ "Returns the namespace of the dialect.")
.def(
"__repr__",
[](PyDialectDescriptor &self) {
@@ -3143,35 +3163,43 @@ void mlir::python::populateIRCore(nb::module_ &m) {
repr.append(">");
return repr;
},
- nb::sig("def __repr__(self) -> str"));
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect descriptor.");
//----------------------------------------------------------------------------
// Mapping of PyDialects
//----------------------------------------------------------------------------
nb::class_<PyDialects>(m, "Dialects")
- .def("__getitem__",
- [=](PyDialects &self, std::string keyName) {
- MlirDialect dialect =
- self.getDialectForKey(keyName, /*attrError=*/false);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(keyName, std::move(descriptor));
- })
- .def("__getattr__", [=](PyDialects &self, std::string attrName) {
- MlirDialect dialect =
- self.getDialectForKey(attrName, /*attrError=*/true);
- nb::object descriptor =
- nb::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(attrName, std::move(descriptor));
- });
+ .def(
+ "__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ },
+ "Gets a dialect by name using subscript notation.")
+ .def(
+ "__getattr__",
+ [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ nb::object descriptor =
+ nb::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ },
+ "Gets a dialect by name using attribute notation.");
//----------------------------------------------------------------------------
// Mapping of PyDialect
//----------------------------------------------------------------------------
nb::class_<PyDialect>(m, "Dialect")
- .def(nb::init<nb::object>(), nb::arg("descriptor"))
- .def_prop_ro("descriptor",
- [](PyDialect &self) { return self.getDescriptor(); })
+ .def(nb::init<nb::object>(), nb::arg("descriptor"),
+ "Creates a Dialect from a DialectDescriptor.")
+ .def_prop_ro(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); },
+ "Returns the DialectDescriptor for this dialect.")
.def(
"__repr__",
[](const nb::object &self) {
@@ -3181,31 +3209,43 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::str(" (class ") + clazz.attr("__module__") +
nb::str(".") + clazz.attr("__name__") + nb::str(")>");
},
- nb::sig("def __repr__(self) -> str"));
+ nb::sig("def __repr__(self) -> str"),
+ "Returns a string representation of the dialect.");
//----------------------------------------------------------------------------
// Mapping of PyDialectRegistry
//----------------------------------------------------------------------------
nb::class_<PyDialectRegistry>(m, "DialectRegistry")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule,
+ "Gets a capsule wrapping the MlirDialectRegistry.")
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyDialectRegistry::createFromCapsule)
- .def(nb::init<>());
+ &PyDialectRegistry::createFromCapsule,
+ "Creates a DialectRegistry from a capsule wrapping "
+ "`MlirDialectRegistry`.")
+ .def(nb::init<>(), "Creates a new empty dialect registry.");
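A sketch of how a DialectRegistry is typically wired into a context (the
registry is normally populated by downstream extension modules; illustrative
only):

    from mlir.ir import Context, DialectRegistry

    registry = DialectRegistry()
    with Context() as ctx:
        ctx.append_dialect_registry(registry)
        ctx.load_all_available_dialects()
        builtin = ctx.dialects["builtin"]        # same as ctx.d.builtin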
//----------------------------------------------------------------------------
// Mapping of Location
//----------------------------------------------------------------------------
nb::class_<PyLocation>(m, "Location")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
- .def("__enter__", &PyLocation::contextEnter)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule,
+ "Gets a capsule wrapping the MlirLocation.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
+ "Creates a Location from a capsule wrapping MlirLocation.")
+ .def("__enter__", &PyLocation::contextEnter,
+ "Enters the location as a context manager.")
.def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
- nb::arg("exc_value").none(), nb::arg("traceback").none())
- .def("__eq__",
- [](PyLocation &self, PyLocation &other) -> bool {
- return mlirLocationEqual(self, other);
- })
- .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
+ nb::arg("exc_value").none(), nb::arg("traceback").none(),
+ "Exits the location context manager.")
+ .def(
+ "__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ },
+ "Compares two locations for equality.")
+ .def(
+ "__eq__", [](PyLocation &self, nb::object other) { return false; },
+ "Compares location with non-location object (always returns False).")
.def_prop_ro_static(
"current",
[](nb::object & /*class*/) -> std::optional<PyLocation *> {
@@ -3217,7 +3257,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// clang-format off
nb::sig("def current(/) -> Location | None"),
// clang-format on
- "Gets the Location bound to the current thread or raises ValueError")
+          "Gets the Location bound to the current thread or returns None if "
+          "no location is set.")
.def_static(
"unknown",
[](DefaultingPyMlirContext context) {
@@ -3225,13 +3265,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirLocationUnknownGet(context->get()));
},
nb::arg("context") = nb::none(),
- "Gets a Location representing an unknown location")
+ "Gets a Location representing an unknown location.")
.def_static(
"callsite",
[](PyLocation callee, const std::vector<PyLocation> &frames,
DefaultingPyMlirContext context) {
if (frames.empty())
- throw nb::value_error("No caller frames provided");
+ throw nb::value_error("No caller frames provided.");
MlirLocation caller = frames.back().get();
for (const PyLocation &frame :
llvm::reverse(llvm::ArrayRef(frames).drop_back()))
@@ -3240,18 +3280,23 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirLocationCallSiteGet(callee.get(), caller));
},
nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(),
- kContextGetCallSiteLocationDocstring)
- .def("is_a_callsite", mlirLocationIsACallSite)
- .def_prop_ro("callee",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCallee(self));
- })
- .def_prop_ro("caller",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationCallSiteGetCaller(self));
- })
+ "Gets a Location representing a caller and callsite.")
+ .def("is_a_callsite", mlirLocationIsACallSite,
+ "Returns True if this location is a CallSiteLoc.")
+ .def_prop_ro(
+ "callee",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCallee(self));
+ },
+ "Gets the callee location from a CallSiteLoc.")
+ .def_prop_ro(
+ "caller",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationCallSiteGetCaller(self));
+ },
+ "Gets the caller location from a CallSiteLoc.")
.def_static(
"file",
[](std::string filename, int line, int col,
@@ -3262,7 +3307,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
context->get(), toMlirStringRef(filename), line, col));
},
nb::arg("filename"), nb::arg("line"), nb::arg("col"),
- nb::arg("context") = nb::none(), kContextGetFileLocationDocstring)
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column.")
.def_static(
"file",
[](std::string filename, int startLine, int startCol, int endLine,
@@ -3274,17 +3320,25 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
nb::arg("end_line"), nb::arg("end_col"),
- nb::arg("context") = nb::none(), kContextGetFileRangeDocstring)
- .def("is_a_file", mlirLocationIsAFileLineColRange)
- .def_prop_ro("filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- })
- .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
- .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
- .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
- .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a file, line and column range.")
+ .def("is_a_file", mlirLocationIsAFileLineColRange,
+ "Returns True if this location is a FileLineColLoc.")
+ .def_prop_ro(
+ "filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ },
+ "Gets the filename from a FileLineColLoc.")
+ .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine,
+ "Gets the start line number from a `FileLineColLoc`.")
+ .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn,
+ "Gets the start column number from a `FileLineColLoc`.")
+ .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine,
+ "Gets the end line number from a `FileLineColLoc`.")
+ .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn,
+ "Gets the end column number from a `FileLineColLoc`.")
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations,
@@ -3300,8 +3354,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyLocation(context->getRef(), location);
},
nb::arg("locations"), nb::arg("metadata") = nb::none(),
- nb::arg("context") = nb::none(), kContextGetFusedLocationDocstring)
- .def("is_a_fused", mlirLocationIsAFused)
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a fused location with optional "
+ "metadata.")
+ .def("is_a_fused", mlirLocationIsAFused,
+ "Returns True if this location is a `FusedLoc`.")
.def_prop_ro(
"locations",
[](PyLocation &self) {
@@ -3314,7 +3371,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
for (unsigned i = 0; i < numLocations; ++i)
pyLocations.emplace_back(self.getContext(), locations[i]);
return pyLocations;
- })
+ },
+ "Gets the list of locations from a `FusedLoc`.")
.def_static(
"name",
[](std::string name, std::optional<PyLocation> childLoc,
@@ -3327,17 +3385,24 @@ void mlir::python::populateIRCore(nb::module_ &m) {
: mlirLocationUnknownGet(context->get())));
},
nb::arg("name"), nb::arg("childLoc") = nb::none(),
- nb::arg("context") = nb::none(), kContextGetNameLocationDocString)
- .def("is_a_name", mlirLocationIsAName)
- .def_prop_ro("name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- })
- .def_prop_ro("child_loc",
- [](PyLocation &self) {
- return PyLocation(self.getContext(),
- mlirLocationNameGetChildLoc(self));
- })
+ nb::arg("context") = nb::none(),
+ "Gets a Location representing a named location with optional child "
+ "location.")
+ .def("is_a_name", mlirLocationIsAName,
+ "Returns True if this location is a `NameLoc`.")
+ .def_prop_ro(
+ "name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ },
+ "Gets the name string from a `NameLoc`.")
+ .def_prop_ro(
+ "child_loc",
+ [](PyLocation &self) {
+ return PyLocation(self.getContext(),
+ mlirLocationNameGetChildLoc(self));
+ },
+ "Gets the child location from a `NameLoc`.")
.def_static(
"from_attr",
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
@@ -3345,41 +3410,59 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirLocationFromAttribute(attribute));
},
nb::arg("attribute"), nb::arg("context") = nb::none(),
- "Gets a Location from a LocationAttr")
+ "Gets a Location from a `LocationAttr`.")
.def_prop_ro(
"context",
[](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
},
- "Context that owns the Location")
+ "Context that owns the `Location`.")
.def_prop_ro(
"attr",
[](PyLocation &self) {
return PyAttribute(self.getContext(),
mlirLocationGetAttribute(self));
},
- "Get the underlying LocationAttr")
+ "Get the underlying `LocationAttr`.")
.def(
"emit_error",
[](PyLocation &self, std::string message) {
mlirEmitError(self, message.c_str());
},
- nb::arg("message"), "Emits an error at this location")
- .def("__repr__", [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- });
+ nb::arg("message"),
+ R"(
+ Emits an error diagnostic at this location.
+
+ Args:
+ message: The error message to emit.)")
+ .def(
+ "__repr__",
+ [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ "Returns the assembly representation of the location.");
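The Location constructors above compose as follows (a minimal sketch, assuming
the usual `mlir.ir` package):

    from mlir.ir import Context, Location

    with Context():
        file_loc = Location.file("example.py", line=10, col=4)
        named = Location.name("my_value", childLoc=file_loc)
        fused = Location.fused([file_loc, named])
        assert fused.is_a_fused()
        with Location.unknown():
            pass                                 # ops created here default to this location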
//----------------------------------------------------------------------------
// Mapping of Module
//----------------------------------------------------------------------------
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule,
+ "Gets a capsule wrapping the MlirModule.")
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- kModuleCAPICreate)
- .def("_clear_mlir_module", &PyModule::clearMlirModule)
+ R"(
+ Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`).
+
+ This returns a new object **BUT** `_clear_mlir_module(module)` must be called to
+ prevent double-frees (of the underlying `mlir::Module`).)")
+ .def("_clear_mlir_module", &PyModule::clearMlirModule,
+ R"(
+ Clears the internal MLIR module reference.
+
+ This is used internally to prevent double-free when ownership is transferred
+ via the C API capsule mechanism. Not intended for normal use.)")
.def_static(
"parse",
[](const std::string &moduleAsm, DefaultingPyMlirContext context)
@@ -3427,13 +3510,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
- nb::arg("loc") = nb::none(), "Creates an empty module")
+ nb::arg("loc") = nb::none(), "Creates an empty module.")
.def_prop_ro(
"context",
[](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
},
- "Context that created the Module")
+ "Context that created the `Module`.")
.def_prop_ro(
"operation",
[](PyModule &self) -> nb::typed<nb::object, PyOperation> {
@@ -3442,7 +3525,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
self.getRef().releaseObject())
.releaseObject();
},
- "Accesses the module as an operation")
+ "Accesses the module as an operation.")
.def_prop_ro(
"body",
[](PyModule &self) {
@@ -3452,7 +3535,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
return returnBlock;
},
- "Return the block for this module")
+ "Return the block for this module.")
.def(
"dump",
[](PyModule &self) {
@@ -3465,39 +3548,59 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
- nb::sig("def __str__(self) -> str"), kOperationStrDunderDocstring)
+ nb::sig("def __str__(self) -> str"),
+ R"(
+ Gets the assembly form of the operation with default options.
+
+          If more advanced control over the assembly formatting or I/O options is
+          needed, use the dedicated print or get_asm methods, which support keyword
+          arguments to customize behavior.
+ )")
.def(
"__eq__",
[](PyModule &self, PyModule &other) {
return mlirModuleEqual(self.get(), other.get());
},
- "other"_a)
- .def("__hash__",
- [](PyModule &self) { return mlirModuleHashValue(self.get()); });
+ "other"_a, "Compares two modules for equality.")
+ .def(
+ "__hash__",
+ [](PyModule &self) { return mlirModuleHashValue(self.get()); },
+ "Returns the hash value of the module.");
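A minimal sketch of parsing and inspecting a module with the bindings above
(illustrative only):

    from mlir.ir import Context, Module

    with Context():
        module = Module.parse("module {}")
        print(module)                            # defers to module.operation.__str__
        block = module.body                      # the single block of the top-level region
        op = module.operation                    # view the module as a generic operation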
//----------------------------------------------------------------------------
// Mapping of Operation.
//----------------------------------------------------------------------------
nb::class_<PyOperationBase>(m, "_OperationBase")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
- [](PyOperationBase &self) {
- return self.getOperation().getCapsule();
- })
- .def("__eq__",
- [](PyOperationBase &self, PyOperationBase &other) {
- return mlirOperationEqual(self.getOperation().get(),
- other.getOperation().get());
- })
- .def("__eq__",
- [](PyOperationBase &self, nb::object other) { return false; })
- .def("__hash__",
- [](PyOperationBase &self) {
- return mlirOperationHashValue(self.getOperation().get());
- })
- .def_prop_ro("attributes",
- [](PyOperationBase &self) {
- return PyOpAttributeMap(self.getOperation().getRef());
- })
+ .def_prop_ro(
+ MLIR_PYTHON_CAPI_PTR_ATTR,
+ [](PyOperationBase &self) {
+ return self.getOperation().getCapsule();
+ },
+ "Gets a capsule wrapping the `MlirOperation`.")
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, PyOperationBase &other) {
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
+ },
+ "Compares two operations for equality.")
+ .def(
+ "__eq__",
+ [](PyOperationBase &self, nb::object other) { return false; },
+ "Compares operation with non-operation object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyOperationBase &self) {
+ return mlirOperationHashValue(self.getOperation().get());
+ },
+ "Returns the hash value of the operation.")
+ .def_prop_ro(
+ "attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(self.getOperation().getRef());
+ },
+ "Returns a dictionary-like map of operation attributes.")
.def_prop_ro(
"context",
[](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
@@ -3505,22 +3608,28 @@ void mlir::python::populateIRCore(nb::module_ &m) {
concreteOperation.checkValid();
return concreteOperation.getContext().getObject();
},
- "Context that owns the Operation")
- .def_prop_ro("name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- })
- .def_prop_ro("operands",
- [](PyOperationBase &self) {
- return PyOpOperandList(self.getOperation().getRef());
- })
- .def_prop_ro("regions",
- [](PyOperationBase &self) {
- return PyRegionList(self.getOperation().getRef());
- })
+ "Context that owns the operation.")
+ .def_prop_ro(
+ "name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ },
+ "Returns the fully qualified name of the operation.")
+ .def_prop_ro(
+ "operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(self.getOperation().getRef());
+ },
+ "Returns the list of operation operands.")
+ .def_prop_ro(
+ "regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(self.getOperation().getRef());
+ },
+ "Returns the list of operation regions.")
.def_prop_ro(
"results",
[](PyOperationBase &self) {
@@ -3551,14 +3660,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"defined or derived from."),
nb::for_setter("Sets the source location the operation was defined "
"or derived from."))
- .def_prop_ro("parent",
- [](PyOperationBase &self)
- -> std::optional<nb::typed<nb::object, PyOperation>> {
- auto parent = self.getOperation().getParentOperation();
- if (parent)
- return parent->getObject();
- return {};
- })
+ .def_prop_ro(
+ "parent",
+ [](PyOperationBase &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
+ auto parent = self.getOperation().getParentOperation();
+ if (parent)
+ return parent->getObject();
+ return {};
+ },
+ "Returns the parent operation, or `None` if at top level.")
.def(
"__str__",
[](PyOperationBase &self) {
@@ -3579,7 +3690,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::overload_cast<PyAsmState &, nb::object, bool>(
&PyOperationBase::print),
nb::arg("state"), nb::arg("file") = nb::none(),
- nb::arg("binary") = false, kOperationPrintStateDocstring)
+ nb::arg("binary") = false,
+ R"(
+          Prints the assembly form of the operation to a file-like object.
+
+          Args:
+            state: `AsmState` capturing the operation numbering and flags.
+            file: Optional file-like object to write to. Defaults to sys.stdout.
+ binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)")
.def("print",
nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>,
bool, bool, bool, bool, bool, bool, nb::object,
@@ -3594,10 +3712,47 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("use_name_loc_as_prefix") = false,
nb::arg("assume_verified") = false, nb::arg("file") = nb::none(),
nb::arg("binary") = false, nb::arg("skip_regions") = false,
- kOperationPrintDocstring)
+ R"(
+          Prints the assembly form of the operation to a file-like object.
+
+ Args:
+            large_elements_limit: Elide elements attributes with more than this
+              number of elements. Defaults to None (no limit).
+            large_resource_limit: Elide resource attributes longer than this
+              number of characters. Defaults to None (no limit). If
+              large_elements_limit is set and this is None, large_elements_limit
+              is also used as large_resource_limit.
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable). Defaults to False.
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as
+ prefixes for the SSA identifiers. Defaults to False.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+            file: The file-like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ skip_regions: Whether to skip printing regions. Defaults to False.)")
.def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
nb::arg("desired_version") = nb::none(),
- kOperationPrintBytecodeDocstring)
+ R"(
+          Writes the bytecode form of the operation to a file-like object.
+
+          Args:
+            file: The file-like object to write to.
+ desired_version: Optional version of bytecode to emit.
+ Returns:
+ The bytecode writer status.)")
.def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
nb::arg("binary") = false,
@@ -3609,7 +3764,17 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("use_local_scope") = false,
nb::arg("use_name_loc_as_prefix") = false,
nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
- kOperationGetAsmDocstring)
+ R"(
+ Gets the assembly form of the operation with all options available.
+
+ Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+ Returns:
+ Either a bytes or str object, depending on the setting of the `binary`
+ argument.)")
.def("verify", &PyOperationBase::verify,
"Verify the operation. Raises MLIRError if verification fails, and "
"returns true otherwise.")
@@ -3621,18 +3786,31 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"block.")
.def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
nb::arg("other"),
- "Given an operation 'other' that is within the same parent block, "
- "return"
- "whether the current operation is before 'other' in the operation "
- "list"
- "of the parent block.")
+ R"(
+ Checks if this operation is before another in the same block.
+
+ Args:
+ other: Another operation in the same parent block.
+
+ Returns:
+ True if this operation is before `other` in the operation list of the parent block.)")
.def(
"clone",
[](PyOperationBase &self,
const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
return self.getOperation().clone(ip);
},
- nb::arg("ip") = nb::none())
+ nb::arg("ip") = nb::none(),
+ R"(
+ Creates a deep copy of the operation.
+
+ Args:
+ ip: Optional insertion point where the cloned operation should be inserted.
+ If None, the current insertion point is used. If False, the operation
+ remains detached.
+
+ Returns:
+ A new Operation that is a clone of this operation.)")
.def(
"detach_from_parent",
[](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
@@ -3653,13 +3831,24 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return operation.isAttached();
},
"Reports if the operation is attached to its parent block.")
- .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
+ .def(
+ "erase", [](PyOperationBase &self) { self.getOperation().erase(); },
+ R"(
+ Erases the operation and frees its memory.
+
+ Note:
+ After erasing, any Python references to the operation become invalid.)")
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
nb::arg("walk_order") = MlirWalkPostOrder,
// clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None")
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
// clang-format on
- );
+ R"(
+ Walks the operation tree with a callback function.
+
+ Args:
+ callback: A callable that takes an Operation and returns a WalkResult.
+ walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)");
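A sketch of the walk API documented above (illustrative only, assuming the
WalkOrder/WalkResult enums exposed by `mlir.ir`):

    from mlir.ir import Context, Module, WalkOrder, WalkResult

    with Context():
        module = Module.parse("module {}")

        def visit(op):
            print(op.name)
            return WalkResult.ADVANCE            # keep walking; INTERRUPT stops the walk

        module.operation.walk(visit, WalkOrder.POST_ORDER)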
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
.def_static(
@@ -3692,7 +3881,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(),
nb::arg("successors") = nb::none(), nb::arg("regions") = 0,
nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(),
- nb::arg("infer_type") = false, kOperationCreateDocstring)
+ nb::arg("infer_type") = false,
+ R"(
+ Creates a new operation.
+
+ Args:
+ name: Operation name (e.g. `dialect.operation`).
+ results: Optional sequence of Type representing op result types.
+ operands: Optional operands of the operation.
+ attributes: Optional Dict of {str: Attribute}.
+ successors: Optional List of Block for the operation's successors.
+ regions: Number of regions to create (default = 0).
+            loc: Optional Location object (defaults to resolving from the context manager).
+ ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager).
+ infer_type: Whether to infer result types (default = False).
+ Returns:
+ A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)")
.def_static(
"parse",
[](const std::string &sourceStr, const std::string &sourceName,
@@ -3705,18 +3909,30 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("context") = nb::none(),
"Parses an operation. Supports both text assembly format and binary "
"bytecode format.")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule,
+ "Gets a capsule wrapping the MlirOperation.")
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyOperation::createFromCapsule)
- .def_prop_ro("operation",
- [](nb::object self) -> nb::typed<nb::object, PyOperation> {
- return self;
- })
- .def_prop_ro("opview",
- [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
- return self.createOpView();
- })
- .def_prop_ro("block", &PyOperation::getBlock)
+ &PyOperation::createFromCapsule,
+ "Creates an Operation from a capsule wrapping MlirOperation.")
+ .def_prop_ro(
+ "operation",
+ [](nb::object self) -> nb::typed<nb::object, PyOperation> {
+ return self;
+ },
+ "Returns self (the operation).")
+ .def_prop_ro(
+ "opview",
+ [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
+ return self.createOpView();
+ },
+ R"(
+ Returns an OpView of this operation.
+
+ Note:
+            If the operation has a registered and loaded dialect, this OpView will
+            be a concrete wrapper class.)")
+ .def_prop_ro("block", &PyOperation::getBlock,
+ "Returns the block containing this operation.")
.def_prop_ro(
"successors",
[](PyOperationBase &self) {
@@ -3830,7 +4046,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("cls"), nb::arg("source"), nb::kw_only(),
nb::arg("source_name") = "", nb::arg("context") = nb::none(),
- "Parses a specific, generated OpView based on class level attributes");
+ "Parses a specific, generated OpView based on class level attributes.");
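A sketch of Operation.create with an unregistered op name (illustrative only;
unregistered dialects must be allowed on the context):

    from mlir.ir import Context, Location, Operation

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        # The result is a detached operation until inserted into a block.
        op = Operation.create("custom.op", results=[], attributes={})
        print(op.name, len(op.regions))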
//----------------------------------------------------------------------------
// Mapping of PyRegion.
@@ -3856,17 +4072,22 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyBlockIterator(self.getParentOperation(), firstBlock);
},
"Iterates over blocks in the region.")
- .def("__eq__",
- [](PyRegion &self, PyRegion &other) {
- return self.get().ptr == other.get().ptr;
- })
- .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
+ .def(
+ "__eq__",
+ [](PyRegion &self, PyRegion &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two regions for pointer equality.")
+ .def(
+ "__eq__", [](PyRegion &self, nb::object &other) { return false; },
+ "Compares region with non-region object (always returns False).");
//----------------------------------------------------------------------------
// Mapping of PyBlock.
//----------------------------------------------------------------------------
nb::class_<PyBlock>(m, "Block")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule,
+ "Gets a capsule wrapping the MlirBlock.")
.def_prop_ro(
"owner",
[](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
@@ -3893,14 +4114,26 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirBlockAddArgument(self.get(), type, loc));
},
"type"_a, "loc"_a,
- "Append an argument of the specified type to the block and returns "
- "the newly added argument.")
+ R"(
+ Appends an argument of the specified type to the block.
+
+ Args:
+ type: The type of the argument to add.
+ loc: The source location for the argument.
+
+ Returns:
+ The newly added block argument.)")
.def(
"erase_argument",
[](PyBlock &self, unsigned index) {
return mlirBlockEraseArgument(self.get(), index);
},
- "Erase the argument at 'index' and remove it from the argument list.")
+ nb::arg("index"),
+ R"(
+ Erases the argument at the specified index.
+
+ Args:
+ index: The index of the argument to erase.)")
.def_prop_ro(
"operations",
[](PyBlock &self) {
@@ -3928,7 +4161,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirBlockDetach(b);
mlirRegionAppendOwnedBlock(region.get(), b);
},
- "Append this block to a region, transferring ownership if necessary")
+ nb::arg("region"),
+ R"(
+ Appends this block to a region.
+
+ Transfers ownership if the block is currently owned by another region.
+
+ Args:
+ region: The region to append the block to.)")
.def(
"create_before",
[](PyBlock &self, const nb::args &pyArgTypes,
@@ -3969,15 +4209,21 @@ void mlir::python::populateIRCore(nb::module_ &m) {
firstOperation);
},
"Iterates over operations in the block.")
- .def("__eq__",
- [](PyBlock &self, PyBlock &other) {
- return self.get().ptr == other.get().ptr;
- })
- .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
- .def("__hash__",
- [](PyBlock &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- })
+ .def(
+ "__eq__",
+ [](PyBlock &self, PyBlock &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two blocks for pointer equality.")
+ .def(
+ "__eq__", [](PyBlock &self, nb::object &other) { return false; },
+ "Compares block with non-block object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyBlock &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the block.")
.def(
"__str__",
[](PyBlock &self) {
@@ -4000,8 +4246,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
self.getParentOperation().getObject());
},
nb::arg("operation"),
- "Appends an operation to this block. If the operation is currently "
- "in another block, it will be moved.")
+ R"(
+ Appends an operation to this block.
+
+ If the operation is currently in another block, it will be moved.
+
+ Args:
+ operation: The operation to append to the block.)")
.def_prop_ro(
"successors",
[](PyBlock &self) {
@@ -4022,10 +4273,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::class_<PyInsertionPoint>(m, "InsertionPoint")
.def(nb::init<PyBlock &>(), nb::arg("block"),
"Inserts after the last operation but still inside the block.")
- .def("__enter__", &PyInsertionPoint::contextEnter)
+ .def("__enter__", &PyInsertionPoint::contextEnter,
+ "Enters the insertion point as a context manager.")
.def("__exit__", &PyInsertionPoint::contextExit,
nb::arg("exc_type").none(), nb::arg("exc_value").none(),
- nb::arg("traceback").none())
+ nb::arg("traceback").none(),
+ "Exits the insertion point context manager.")
.def_prop_ro_static(
"current",
[](nb::object & /*class*/) {
@@ -4036,20 +4289,50 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::sig("def current(/) -> InsertionPoint"),
"Gets the InsertionPoint bound to the current thread or raises "
- "ValueError if none has been set")
+ "ValueError if none has been set.")
.def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
"Inserts before a referenced operation.")
.def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
- nb::arg("block"), "Inserts at the beginning of the block.")
+ nb::arg("block"),
+ R"(
+ Creates an insertion point at the beginning of a block.
+
+ Args:
+ block: The block at whose beginning operations should be inserted.
+
+ Returns:
+ An InsertionPoint at the block's beginning.)")
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
- nb::arg("block"), "Inserts before the block terminator.")
+ nb::arg("block"),
+ R"(
+ Creates an insertion point before a block's terminator.
+
+ Args:
+ block: The block whose terminator to insert before.
+
+ Returns:
+ An InsertionPoint before the terminator.
+
+ Raises:
+ ValueError: If the block has no terminator.)")
.def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
- "Inserts after the operation.")
+ R"(
+ Creates an insertion point immediately after an operation.
+
+ Args:
+ operation: The operation after which to insert.
+
+ Returns:
+ An InsertionPoint after the operation.)")
.def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
- "Inserts an operation.")
+ R"(
+ Inserts an operation at this insertion point.
+
+ Args:
+ operation: The operation to insert.)")
.def_prop_ro(
"block", [](PyInsertionPoint &self) { return self.getBlock(); },
- "Returns the block that this InsertionPoint points to.")
+ "Returns the block that this `InsertionPoint` points to.")
.def_prop_ro(
"ref_operation",
[](PyInsertionPoint &self)
@@ -4061,7 +4344,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"The reference operation before which new operations are "
"inserted, or None if the insertion point is at the end of "
- "the block");
+ "the block.");
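A sketch of building IR with insertion points (illustrative only, assuming the
Module.create empty-module factory bound earlier):

    from mlir.ir import Context, InsertionPoint, Location, Module, Operation

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            Operation.create("custom.op")        # inserted at the end of the body block
        print(module)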
//----------------------------------------------------------------------------
// Mapping of PyAttribute.
@@ -4070,10 +4353,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Delegate to the PyAttribute copy constructor, which will also lifetime
// extend the backing context which owns the MlirAttribute.
.def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
- "Casts the passed attribute to the generic Attribute")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
- &PyAttribute::createFromCapsule)
+ "Casts the passed attribute to the generic `Attribute`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule,
+ "Gets a capsule wrapping the MlirAttribute.")
+ .def_static(
+ MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule,
+ "Creates an Attribute from a capsule wrapping `MlirAttribute`.")
.def_static(
"parse",
[](const std::string &attrSpec, DefaultingPyMlirContext context)
@@ -4086,33 +4371,49 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
},
nb::arg("asm"), nb::arg("context") = nb::none(),
- "Parses an attribute from an assembly form. Raises an MLIRError on "
+ "Parses an attribute from an assembly form. Raises an `MLIRError` on "
"failure.")
.def_prop_ro(
"context",
[](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
},
- "Context that owns the Attribute")
- .def_prop_ro("type",
- [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getContext(),
- mlirAttributeGetType(self))
- .maybeDownCast();
- })
+ "Context that owns the `Attribute`.")
+ .def_prop_ro(
+ "type",
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirAttributeGetType(self))
+ .maybeDownCast();
+ },
+ "Returns the type of the `Attribute`.")
.def(
"get_named",
[](PyAttribute &self, std::string name) {
return PyNamedAttribute(self, std::move(name));
},
- nb::keep_alive<0, 1>(), "Binds a name to the attribute")
- .def("__eq__",
- [](PyAttribute &self, PyAttribute &other) { return self == other; })
- .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
- .def("__hash__",
- [](PyAttribute &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- })
+ nb::keep_alive<0, 1>(),
+ R"(
+ Binds a name to the attribute, creating a `NamedAttribute`.
+
+ Args:
+ name: The name to bind to the `Attribute`.
+
+ Returns:
+ A `NamedAttribute` with the given name and this attribute.)")
+ .def(
+ "__eq__",
+ [](PyAttribute &self, PyAttribute &other) { return self == other; },
+ "Compares two attributes for equality.")
+ .def(
+ "__eq__", [](PyAttribute &self, nb::object &other) { return false; },
+ "Compares attribute with non-attribute object (always returns "
+ "False).")
+ .def(
+ "__hash__",
+ [](PyAttribute &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the attribute.")
.def(
"dump", [](PyAttribute &self) { mlirAttributeDump(self); },
kDumpDocstring)
@@ -4125,61 +4426,69 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return printAccum.join();
},
"Returns the assembly form of the Attribute.")
- .def("__repr__",
- [](PyAttribute &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, attribute values are generally considered useful and
- // are printed. This may need to be re-evaluated if debug dumps end
- // up being excessive.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Attribute(");
- mlirAttributePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- })
- .def_prop_ro("typeid",
- [](PyAttribute &self) {
- MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
- assert(!mlirTypeIDIsNull(mlirTypeID) &&
- "mlirTypeID was expected to be non-null.");
- return PyTypeID(mlirTypeID);
- })
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
- return self.maybeDownCast();
- });
+ .def(
+ "__repr__",
+ [](PyAttribute &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, attribute values are generally considered useful and
+ // are printed. This may need to be re-evaluated if debug dumps end
+ // up being excessive.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Attribute(");
+ mlirAttributePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the attribute.")
+ .def_prop_ro(
+ "typeid",
+ [](PyAttribute &self) {
+ MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ return PyTypeID(mlirTypeID);
+ },
+ "Returns the `TypeID` of the attribute.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the attribute to a more specific attribute if possible.");
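A sketch of the Attribute API above (illustrative only):

    from mlir.ir import Attribute, Context

    with Context():
        attr = Attribute.parse("42 : i32")       # raises MLIRError on failure
        named = attr.get_named("value")
        print(attr.type, attr.typeid, named.name)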
//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
//----------------------------------------------------------------------------
nb::class_<PyNamedAttribute>(m, "NamedAttribute")
- .def("__repr__",
- [](PyNamedAttribute &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("NamedAttribute(");
- printAccum.parts.append(
- nb::str(mlirIdentifierStr(self.namedAttr.name).data,
- mlirIdentifierStr(self.namedAttr.name).length));
- printAccum.parts.append("=");
- mlirAttributePrint(self.namedAttr.attribute,
- printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- })
+ .def(
+ "__repr__",
+ [](PyNamedAttribute &self) {
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("NamedAttribute(");
+ printAccum.parts.append(
+ nb::str(mlirIdentifierStr(self.namedAttr.name).data,
+ mlirIdentifierStr(self.namedAttr.name).length));
+ printAccum.parts.append("=");
+ mlirAttributePrint(self.namedAttr.attribute,
+ printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the named attribute.")
.def_prop_ro(
"name",
[](PyNamedAttribute &self) {
return mlirIdentifierStr(self.namedAttr.name);
},
- "The name of the NamedAttribute binding")
+ "The name of the `NamedAttribute` binding.")
.def_prop_ro(
"attr",
[](PyNamedAttribute &self) { return self.namedAttr.attribute; },
nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"),
- "The underlying generic attribute of the NamedAttribute binding");
+ "The underlying generic attribute of the `NamedAttribute` binding.");
//----------------------------------------------------------------------------
// Mapping of PyType.
@@ -4188,9 +4497,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Delegate to the PyType copy constructor, which will also lifetime
// extend the backing context which owns the MlirType.
.def(nb::init<PyType &>(), nb::arg("cast_from_type"),
- "Casts the passed type to the generic Type")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
+ "Casts the passed type to the generic `Type`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule,
+ "Gets a capsule wrapping the `MlirType`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule,
+ "Creates a Type from a capsule wrapping `MlirType`.")
.def_static(
"parse",
[](std::string typeSpec,
@@ -4203,21 +4514,31 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return PyType(context.get()->getRef(), type).maybeDownCast();
},
nb::arg("asm"), nb::arg("context") = nb::none(),
- kContextParseTypeDocstring)
+ R"(
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an `MLIRError` if the type cannot be parsed.
+
+ See also: https://mlir.llvm.org/docs/LangRef/#type-system)")
.def_prop_ro(
"context",
[](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
},
- "Context that owns the Type")
- .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
+ "Context that owns the `Type`.")
+ .def(
+ "__eq__", [](PyType &self, PyType &other) { return self == other; },
+ "Compares two types for equality.")
.def(
"__eq__", [](PyType &self, nb::object &other) { return false; },
- nb::arg("other").none())
- .def("__hash__",
- [](PyType &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- })
+ nb::arg("other").none(),
+ "Compares type with non-type object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyType &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the `Type`.")
.def(
"dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
.def(
@@ -4228,60 +4549,84 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.getUserData());
return printAccum.join();
},
- "Returns the assembly form of the type.")
- .def("__repr__",
- [](PyType &self) {
- // Generally, assembly formats are not printed for __repr__ because
- // this can cause exceptionally long debug output and exceptions.
- // However, types are an exception as they typically have compact
- // assembly forms and printing them is useful.
- PyPrintAccumulator printAccum;
- printAccum.parts.append("Type(");
- mlirTypePrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- })
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyType &self) -> nb::typed<nb::object, PyType> {
- return self.maybeDownCast();
- })
- .def_prop_ro("typeid", [](PyType &self) {
- MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
- if (!mlirTypeIDIsNull(mlirTypeID))
- return PyTypeID(mlirTypeID);
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
- throw nb::value_error(
- (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
- });
+ "Returns the assembly form of the `Type`.")
+ .def(
+ "__repr__",
+ [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ },
+ "Returns a string representation of the `Type`.")
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the Type to a more specific `Type` if possible.")
+ .def_prop_ro(
+ "typeid",
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ if (!mlirTypeIDIsNull(mlirTypeID))
+ return PyTypeID(mlirTypeID);
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
+ throw nb::value_error(
+ (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
+ },
+          "Returns the `TypeID` of the `Type`, or raises `ValueError` if the "
+          "`Type` has no `TypeID`.");
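A sketch of the Type API above (illustrative only):

    from mlir.ir import Context, Type

    with Context():
        t = Type.parse("i32")                    # raises MLIRError on failure
        assert t == Type.parse("i32")            # types are uniqued within a context
        print(t, t.typeid)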
//----------------------------------------------------------------------------
// Mapping of PyTypeID.
//----------------------------------------------------------------------------
nb::class_<PyTypeID>(m, "TypeID")
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule,
+ "Gets a capsule wrapping the `MlirTypeID`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule,
+ "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.")
// Note, this tests whether the underlying TypeIDs are the same,
// not whether the wrapper MlirTypeIDs are the same, nor whether
// the Python objects are the same (i.e., PyTypeID is a value type).
- .def("__eq__",
- [](PyTypeID &self, PyTypeID &other) { return self == other; })
- .def("__eq__",
- [](PyTypeID &self, const nb::object &other) { return false; })
+ .def(
+ "__eq__",
+ [](PyTypeID &self, PyTypeID &other) { return self == other; },
+ "Compares two `TypeID`s for equality.")
+ .def(
+ "__eq__",
+ [](PyTypeID &self, const nb::object &other) { return false; },
+ "Compares TypeID with non-TypeID object (always returns False).")
// Note, this gives the hash value of the underlying TypeID, not the
// hash value of the Python object, nor the hash value of the
// MlirTypeID wrapper.
- .def("__hash__", [](PyTypeID &self) {
- return static_cast<size_t>(mlirTypeIDHashValue(self));
- });
+ .def(
+ "__hash__",
+ [](PyTypeID &self) {
+ return static_cast<size_t>(mlirTypeIDHashValue(self));
+ },
+ "Returns the hash value of the `TypeID`.");
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
- nb::class_<PyValue>(m, "Value")
- .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
- .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
- .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
+ m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
+
+ nb::class_<PyValue>(m, "Value", nb::is_generic(),
+ nb::sig("class Value(Generic[_T])"))
+ .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
+ "Creates a Value reference from another `Value`.")
+ .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
+ "Gets a capsule wrapping the `MlirValue`.")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule,
+ "Creates a `Value` from a capsule wrapping `MlirValue`.")
.def_prop_ro(
"context",
[](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
@@ -4312,23 +4657,30 @@ void mlir::python::populateIRCore(nb::module_ &m) {
assert(false && "Value must be a block argument or an op result");
return nb::none();
},
- // clang-format off
- nb::sig("def owner(self) -> Operation | Block | None"))
- // clang-format on
- .def_prop_ro("uses",
- [](PyValue &self) {
- return PyOpOperandIterator(
- mlirValueGetFirstUse(self.get()));
- })
- .def("__eq__",
- [](PyValue &self, PyValue &other) {
- return self.get().ptr == other.get().ptr;
- })
- .def("__eq__", [](PyValue &self, nb::object other) { return false; })
- .def("__hash__",
- [](PyValue &self) {
- return static_cast<size_t>(llvm::hash_value(self.get().ptr));
- })
+ "Returns the owner of the value (`Operation` for results, `Block` "
+ "for "
+ "arguments).")
+ .def_prop_ro(
+ "uses",
+ [](PyValue &self) {
+ return PyOpOperandIterator(mlirValueGetFirstUse(self.get()));
+ },
+ "Returns an iterator over uses of this value.")
+ .def(
+ "__eq__",
+ [](PyValue &self, PyValue &other) {
+ return self.get().ptr == other.get().ptr;
+ },
+ "Compares two values for pointer equality.")
+ .def(
+ "__eq__", [](PyValue &self, nb::object other) { return false; },
+ "Compares value with non-value object (always returns False).")
+ .def(
+ "__hash__",
+ [](PyValue &self) {
+ return static_cast<size_t>(llvm::hash_value(self.get().ptr));
+ },
+ "Returns the hash value of the value.")
.def(
"__str__",
[](PyValue &self) {
@@ -4339,7 +4691,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.parts.append(")");
return printAccum.join();
},
- kValueDunderStrDocstring)
+ R"(
+ Returns the string form of the value.
+
+ If the value is a block argument, this is the assembly form of its type and the
+ position in the argument list. If the value is an operation result, this is
+ equivalent to printing the operation that produced it.
+ )")
.def(
"get_name",
[](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) {
@@ -4359,7 +4717,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return printAccum.join();
},
nb::arg("use_local_scope") = false,
- nb::arg("use_name_loc_as_prefix") = false)
+ nb::arg("use_name_loc_as_prefix") = false,
+ R"(
+ Returns the string form of value as an operand.
+
+ Args:
+ use_local_scope: Whether to use local scope for naming.
+ use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix.
+
+ Returns:
+ The value's name as it appears in IR (e.g., `%0`, `%arg0`).)")
.def(
"get_name",
[](PyValue &self, PyAsmState &state) {
@@ -4370,25 +4737,30 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.getUserData());
return printAccum.join();
},
- nb::arg("state"), kGetNameAsOperand)
- .def_prop_ro("type",
- [](PyValue &self) -> nb::typed<nb::object, PyType> {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
- })
+ nb::arg("state"),
+ "Returns the string form of value as an operand (i.e., the ValueID).")
+ .def_prop_ro(
+ "type",
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast();
+ },
+ "Returns the type of the value.")
.def(
"set_type",
[](PyValue &self, const PyType &type) {
- return mlirValueSetType(self.get(), type);
+ mlirValueSetType(self.get(), type);
},
- nb::arg("type"))
+ nb::arg("type"), "Sets the type of the value.",
+ nb::sig("def set_type(self, type: _T)"))
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
- kValueReplaceAllUsesWithDocstring)
+ "Replace all uses of value with the new value, updating anything in "
+ "the IR that uses `self` to use the other value instead.")
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with, PyOperation &exception) {
@@ -4434,10 +4806,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("with_"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- })
+ .def(
+ MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ },
+ "Downcasts the `Value` to a more specific kind if possible.")
.def_prop_ro(
"location",
[](MlirValue self) {
@@ -4445,7 +4819,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContext::forContext(mlirValueGetContext(self)),
mlirValueGetLocation(self));
},
- "Returns the source location the value");
+ "Returns the source location of the value.");
PyBlockArgument::bind(m);
PyOpResult::bind(m);
@@ -4453,43 +4827,105 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::class_<PyAsmState>(m, "AsmState")
.def(nb::init<PyValue &, bool>(), nb::arg("value"),
- nb::arg("use_local_scope") = false)
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an `AsmState` for consistent SSA value naming.
+
+ Args:
+ value: The value to create state for.
+ use_local_scope: Whether to use local scope for naming.)")
.def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
- nb::arg("use_local_scope") = false);
+ nb::arg("use_local_scope") = false,
+ R"(
+ Creates an `AsmState` for consistent SSA value naming.
+
+ Args:
+ op: The operation to create state for.
+ use_local_scope: Whether to use local scope for naming.)");
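A sketch of how `AsmState` pairs with `Value.get_name`; `func_op` and `some_op` are assumed to be an operation and one of its nested ops:

    from mlir.ir import AsmState

    # Reusing one AsmState keeps SSA value names consistent across queries.
    state = AsmState(func_op, use_local_scope=True)
    names = [result.get_name(state) for result in some_op.results]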
//----------------------------------------------------------------------------
// Mapping of SymbolTable.
//----------------------------------------------------------------------------
nb::class_<PySymbolTable>(m, "SymbolTable")
- .def(nb::init<PyOperationBase &>())
- .def("__getitem__",
- [](PySymbolTable &self,
- const std::string &name) -> nb::typed<nb::object, PyOpView> {
- return self.dunderGetItem(name);
- })
- .def("insert", &PySymbolTable::insert, nb::arg("operation"))
- .def("erase", &PySymbolTable::erase, nb::arg("operation"))
- .def("__delitem__", &PySymbolTable::dunderDel)
- .def("__contains__",
- [](PySymbolTable &table, const std::string &name) {
- return !mlirOperationIsNull(mlirSymbolTableLookup(
- table, mlirStringRefCreate(name.data(), name.length())));
- })
+ .def(nb::init<PyOperationBase &>(),
+ R"(
+ Creates a symbol table for an operation.
+
+ Args:
+ operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`).
+
+ Raises:
+ TypeError: If the operation is not a symbol table.)")
+ .def(
+ "__getitem__",
+ [](PySymbolTable &self,
+ const std::string &name) -> nb::typed<nb::object, PyOpView> {
+ return self.dunderGetItem(name);
+ },
+ R"(
+ Looks up a symbol by name in the symbol table.
+
+ Args:
+ name: The name of the symbol to look up.
+
+ Returns:
+ The operation defining the symbol.
+
+ Raises:
+ KeyError: If the symbol is not found.)")
+ .def("insert", &PySymbolTable::insert, nb::arg("operation"),
+ R"(
+ Inserts a symbol operation into the symbol table.
+
+ Args:
+ operation: An operation with a symbol name to insert.
+
+ Returns:
+ The symbol name attribute of the inserted operation.
+
+ Raises:
+ ValueError: If the operation does not have a symbol name.)")
+ .def("erase", &PySymbolTable::erase, nb::arg("operation"),
+ R"(
+ Erases a symbol operation from the symbol table.
+
+ Args:
+ operation: The symbol operation to erase.
+
+ Note:
+ The operation is also erased from the IR and invalidated.)")
+ .def("__delitem__", &PySymbolTable::dunderDel,
+ "Deletes a symbol by name from the symbol table.")
+ .def(
+ "__contains__",
+ [](PySymbolTable &table, const std::string &name) {
+ return !mlirOperationIsNull(mlirSymbolTableLookup(
+ table, mlirStringRefCreate(name.data(), name.length())));
+ },
+ "Checks if a symbol with the given name exists in the table.")
// Static helpers.
.def_static("set_symbol_name", &PySymbolTable::setSymbolName,
- nb::arg("symbol"), nb::arg("name"))
+ nb::arg("symbol"), nb::arg("name"),
+ "Sets the symbol name for a symbol operation.")
.def_static("get_symbol_name", &PySymbolTable::getSymbolName,
- nb::arg("symbol"))
+ nb::arg("symbol"),
+ "Gets the symbol name from a symbol operation.")
.def_static("get_visibility", &PySymbolTable::getVisibility,
- nb::arg("symbol"))
+ nb::arg("symbol"),
+ "Gets the visibility attribute of a symbol operation.")
.def_static("set_visibility", &PySymbolTable::setVisibility,
- nb::arg("symbol"), nb::arg("visibility"))
+ nb::arg("symbol"), nb::arg("visibility"),
+ "Sets the visibility attribute of a symbol operation.")
.def_static("replace_all_symbol_uses",
&PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
- nb::arg("new_symbol"), nb::arg("from_op"))
+ nb::arg("new_symbol"), nb::arg("from_op"),
+ "Replaces all uses of a symbol with a new symbol name within "
+ "the given operation.")
.def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
- nb::arg("callback"));
+ nb::arg("callback"),
+ "Walks symbol tables starting from an operation with a "
+ "callback function.");
// Container bindings.
PyBlockArgumentList::bind(m);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a14f09f..ba767ad 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -24,6 +24,8 @@ using namespace mlir::python;
NB_MODULE(_mlir, m) {
m.doc() = "MLIR Python Native Extension";
+ m.attr("T") = nb::type_var("T");
+ m.attr("U") = nb::type_var("U");
nb::class_<PyGlobals>(m, "_Globals")
.def_prop_rw("dialect_search_modules",
@@ -102,6 +104,10 @@ NB_MODULE(_mlir, m) {
return opClass;
});
},
+ // clang-format off
+ nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) "
+ "-> typing.Callable[[type[T]], type[T]]"),
+ // clang-format on
"dialect_class"_a, nb::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
@@ -114,6 +120,10 @@ NB_MODULE(_mlir, m) {
return typeCaster;
});
},
+ // clang-format off
+ nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+ "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+ // clang-format on
"typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
m.def(
@@ -126,6 +136,10 @@ NB_MODULE(_mlir, m) {
return valueCaster;
});
},
+ // clang-format off
+ nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) "
+ "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"),
+ // clang-format on
"typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a value caster for casting MLIR values to custom user values.");
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index 64ea4329..aea195f 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -19,6 +19,7 @@
#include "llvm/Support/raw_ostream.h"
#include <string>
+#include <typeinfo>
#include <variant>
template <>
@@ -344,7 +345,16 @@ public:
/// Binds the indexing and length methods in the Python class.
static void bind(nanobind::module_ &m) {
- auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
+ const std::type_info &elemTy = typeid(ElementTy);
+ PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
+ assert(elemTyInfo &&
+ "expected nb_type_lookup to succeed for Sliceable elemTy");
+ nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
+ std::string sig = std::string("class ") + Derived::pyClassName +
+ "(collections.abc.Sequence[" +
+ nanobind::cast<std::string>(elemTyName) + "])";
+ auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
+ nanobind::sig(sig.c_str()))
.def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);
@@ -395,7 +405,6 @@ public:
/// Hook for derived classes willing to bind more methods.
static void bindDerived(ClassTy &) {}
-private:
intptr_t startIndex;
intptr_t length;
intptr_t step;
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1659437..0ac5fc5 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -27,6 +27,7 @@
#include <cstddef>
#include <cstdint>
+#include <deque>
#include <list>
#include <memory>
#include <numeric>
@@ -830,6 +831,23 @@ namespace {
/// This class provides support for reading attribute and type entries from the
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
/// this reader to manage when to actually parse them from the bytecode.
+///
+/// Parsing of attributes and types is generally recursive, which can lead to
+/// stack overflows for deeply nested structures, so we track a few extra
+/// pieces of information to avoid this:
+///
+/// - `depth`: The current depth while parsing nested attributes/types. We
+/// defer parsing of deeply nested attributes to avoid potential stack
+/// overflows. Deferral is implemented by reporting a failure when parsing a
+/// nested attribute/type and registering the index of the encountered
+/// attribute/type in the deferred parsing worklist. Hence, a failure on a
+/// deferred entry is not a hard error; it does, however, require callers to
+/// return on the first failure rather than attempting additional parses.
+/// - `deferredWorklist`: The indices of attributes/types that could not be
+/// parsed because the depth limit was hit. These entries are resolved
+/// iteratively later, which moves the tracking of pending work from the
+/// stack to the heap.
class AttrTypeReader {
/// This class represents a single attribute or type entry.
template <typename T>
@@ -863,12 +881,34 @@ public:
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
+ /// Read the attribute/type at the given index into `result`, tracking the
+ /// current nesting `depth` so that deeply nested entries can be deferred.
+ LogicalResult readAttribute(uint64_t index, Attribute &result,
+ uint64_t depth = 0) {
+ return readEntry(attributes, index, result, "attribute", depth);
+ }
+
+ LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
+ return readEntry(types, index, result, "type", depth);
+ }
+
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
- Attribute resolveAttribute(size_t index) {
- return resolveEntry(attributes, index, "Attribute");
+ Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
+ return resolveEntry(attributes, index, "Attribute", depth);
+ }
+ Type resolveType(size_t index, uint64_t depth = 0) {
+ return resolveEntry(types, index, "Type", depth);
+ }
+
+ /// Return the attribute at the given index if it has already been resolved,
+ /// or a null sentinel otherwise.
+ Attribute getAttributeOrSentinel(size_t index) {
+ if (index >= attributes.size())
+ return nullptr;
+ return attributes[index].entry;
+ }
+ /// Return the type at the given index if it has already been resolved, or a
+ /// null sentinel otherwise.
+ Type getTypeOrSentinel(size_t index) {
+ if (index >= types.size())
+ return nullptr;
+ return types[index].entry;
}
- Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
/// Parse a reference to an attribute or type using the given reader.
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
@@ -909,23 +949,33 @@ public:
llvm::getTypeName<T>(), ", but got: ", baseResult);
}
+ /// Add an index to the deferred worklist for re-parsing.
+ void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
+
private:
/// Resolve the given entry at `index`.
template <typename T>
- T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
- StringRef entryType);
+ T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
+ StringRef entryType, uint64_t depth = 0);
- /// Parse an entry using the given reader that was encoded using the textual
- /// assembly format.
+ /// Read the entry at the given index, parsing it if it has not yet been
+ /// resolved. Returns failure if parsing fails or is deferred.
template <typename T>
- LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
- StringRef entryType);
+ LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
+ T &result, StringRef entryType, uint64_t depth);
/// Parse an entry using the given reader that was encoded using a custom
/// bytecode format.
template <typename T>
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
- StringRef entryType);
+ StringRef entryType, uint64_t index,
+ uint64_t depth);
+
+ /// Parse an entry using the given reader that was encoded using the textual
+ /// assembly format.
+ template <typename T>
+ LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
+ StringRef entryType);
/// The string section reader used to resolve string references when parsing
/// custom encoded attribute/type entries.
@@ -951,6 +1001,10 @@ private:
/// Reference to the parser configuration.
const ParserConfig &parserConfig;
+
+ /// Worklist for deferred attribute/type parsing. This is used to handle
+ /// deeply nested structures like CallSiteLoc iteratively.
+ std::vector<uint64_t> deferredWorklist;
};
class DialectReader : public DialectBytecodeReader {
@@ -959,10 +1013,11 @@ public:
const StringSectionReader &stringReader,
const ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
- EncodingReader &reader, uint64_t &bytecodeVersion)
+ EncodingReader &reader, uint64_t &bytecodeVersion,
+ uint64_t depth = 0)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), dialectsMap(dialectsMap),
- reader(reader), bytecodeVersion(bytecodeVersion) {}
+ reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
@@ -998,14 +1053,40 @@ public:
// IR
//===--------------------------------------------------------------------===//
+ /// The maximum depth to eagerly parse nested attributes/types before
+ /// deferring.
+ static constexpr uint64_t maxAttrTypeDepth = 5;
+
LogicalResult readAttribute(Attribute &result) override {
- return attrTypeReader.parseAttribute(reader, result);
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return failure();
+ if (depth > maxAttrTypeDepth) {
+ if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
+ result = attr;
+ return success();
+ }
+ attrTypeReader.addDeferredParsing(index);
+ return failure();
+ }
+ return attrTypeReader.readAttribute(index, result, depth + 1);
}
LogicalResult readOptionalAttribute(Attribute &result) override {
return attrTypeReader.parseOptionalAttribute(reader, result);
}
LogicalResult readType(Type &result) override {
- return attrTypeReader.parseType(reader, result);
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return failure();
+ if (depth > maxAttrTypeDepth) {
+ if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
+ result = type;
+ return success();
+ }
+ attrTypeReader.addDeferredParsing(index);
+ return failure();
+ }
+ return attrTypeReader.readType(index, result, depth + 1);
}
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
@@ -1095,6 +1176,7 @@ private:
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
+ uint64_t depth;
};
/// Wraps the properties section and handles reading properties out of it.
@@ -1238,69 +1320,112 @@ LogicalResult AttrTypeReader::initialize(
}
template <typename T>
-T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
- StringRef entryType) {
+T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
+ uint64_t index, StringRef entryType,
+ uint64_t depth) {
if (index >= entries.size()) {
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
return {};
}
- // If the entry has already been resolved, there is nothing left to do.
- Entry<T> &entry = entries[index];
- if (entry.entry)
- return entry.entry;
+ // Fast path: Try direct parsing without worklist overhead. This handles the
+ // common case where there are no deferred dependencies.
+ assert(deferredWorklist.empty());
+ T result;
+ if (succeeded(readEntry(entries, index, result, entryType, depth))) {
+ assert(deferredWorklist.empty());
+ return result;
+ }
+ if (deferredWorklist.empty()) {
+ // A failure with no deferred entries is a real error.
+ return T();
+ }
- // Parse the entry.
- EncodingReader reader(entry.data, fileLoc);
+ // Slow path: Use a worklist to handle deferred dependencies. A deque lets
+ // us resolve entries with dependencies iteratively:
+ // - Pop from front to process
+ // - Push new dependencies to front (depth-first)
+ // - Move failed entries to back (retry after dependencies)
+ std::deque<size_t> worklist;
+ llvm::DenseSet<size_t> inWorklist;
- // Parse based on how the entry was encoded.
- if (entry.hasCustomEncoding) {
- if (failed(parseCustomEntry(entry, reader, entryType)))
- return T();
- } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
- return T();
+ // Add the original index and any dependencies from the fast path attempt.
+ worklist.push_back(index);
+ inWorklist.insert(index);
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
}
- if (!reader.empty()) {
- reader.emitError("unexpected trailing bytes after " + entryType + " entry");
- return T();
+ while (!worklist.empty()) {
+ size_t currentIndex = worklist.front();
+ worklist.pop_front();
+
+ // Clear the deferred worklist before parsing to capture any new entries.
+ deferredWorklist.clear();
+
+ T result;
+ if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
+ inWorklist.erase(currentIndex);
+ continue;
+ }
+
+ if (deferredWorklist.empty()) {
+ // Parsing failed with no deferred entries, which implies a real error.
+ return T();
+ }
+
+ // Move this entry to the back to retry after dependencies.
+ worklist.push_back(currentIndex);
+
+ // Add dependencies to the front (in reverse so they maintain order).
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
+ }
+ deferredWorklist.clear();
}
- return entry.entry;
+ return entries[index].entry;
}
template <typename T>
-LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
- StringRef entryType) {
- StringRef asmStr;
- if (failed(reader.parseNullTerminatedString(asmStr)))
- return failure();
+LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
+ uint64_t index, T &result,
+ StringRef entryType, uint64_t depth) {
+ if (index >= entries.size())
+ return emitError(fileLoc) << "invalid " << entryType << " index: " << index;
- // Invoke the MLIR assembly parser to parse the entry text.
- size_t numRead = 0;
- MLIRContext *context = fileLoc->getContext();
- if constexpr (std::is_same_v<T, Type>)
- result =
- ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
- else
- result = ::parseAttribute(asmStr, context, Type(), &numRead,
- /*isKnownNullTerminated=*/true);
- if (!result)
+ // If the entry has already been resolved, return it.
+ Entry<T> &entry = entries[index];
+ if (entry.entry) {
+ result = entry.entry;
+ return success();
+ }
+
+ // If the entry hasn't been resolved, try to parse it.
+ EncodingReader reader(entry.data, fileLoc);
+ LogicalResult parseResult =
+ entry.hasCustomEncoding
+ ? parseCustomEntry(entry, reader, entryType, index, depth)
+ : parseAsmEntry(entry.entry, reader, entryType);
+ if (failed(parseResult))
return failure();
- // Ensure there weren't dangling characters after the entry.
- if (numRead != asmStr.size()) {
- return reader.emitError("trailing characters found after ", entryType,
- " assembly format: ", asmStr.drop_front(numRead));
- }
+ if (!reader.empty())
+ return reader.emitError("unexpected trailing bytes after " + entryType +
+ " entry");
+
+ result = entry.entry;
return success();
}
template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
- StringRef entryType) {
+ StringRef entryType,
+ uint64_t index, uint64_t depth) {
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
- reader, bytecodeVersion);
+ reader, bytecodeVersion, depth);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
@@ -1350,6 +1475,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
return success(!!entry.entry);
}
+template <typename T>
+LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
+ StringRef entryType) {
+ StringRef asmStr;
+ if (failed(reader.parseNullTerminatedString(asmStr)))
+ return failure();
+
+ // Invoke the MLIR assembly parser to parse the entry text.
+ size_t numRead = 0;
+ MLIRContext *context = fileLoc->getContext();
+ if constexpr (std::is_same_v<T, Type>)
+ result =
+ ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
+ else
+ result = ::parseAttribute(asmStr, context, Type(), &numRead,
+ /*isKnownNullTerminated=*/true);
+ if (!result)
+ return failure();
+
+ // Ensure there weren't dangling characters after the entry.
+ if (numRead != asmStr.size()) {
+ return reader.emitError("trailing characters found after ", entryType,
+ " assembly format: ", asmStr.drop_front(numRead));
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Bytecode Reader
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index eaad8a8..bf23176 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
}
+MlirTypeID mlirLLVMPointerTypeGetTypeID() {
+ return wrap(LLVM::LLVMPointerType::getTypeID());
+}
+
bool mlirTypeIsALLVMPointerType(MlirType type) {
return isa<LLVM::LLVMPointerType>(unwrap(type));
}
@@ -73,6 +77,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) {
return isa<LLVM::LLVMStructType>(unwrap(type));
}
+MlirTypeID mlirLLVMStructTypeGetTypeID() {
+ return wrap(LLVM::LLVMStructType::getTypeID());
+}
+
bool mlirLLVMStructTypeIsLiteral(MlirType type) {
return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
}
@@ -159,9 +167,8 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
return wrap(DIExpressionAttr::get(
unwrap(ctx),
- llvm::map_to_vector(
- unwrapList(nOperations, operations, attrStorage),
- [](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
+ llvm::map_to_vector(unwrapList(nOperations, operations, attrStorage),
+ llvm::CastTo<DIExpressionElemAttr>)));
}
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@@ -202,7 +209,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
cast<DIExpressionAttr>(unwrap(allocated)),
cast<DIExpressionAttr>(unwrap(associated)),
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDIDerivedTypeAttrGet(
@@ -308,7 +315,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
return wrap(DISubroutineTypeAttr::get(
unwrap(ctx), callingConvention,
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
- [](Attribute a) { return cast<DITypeAttr>(a); })));
+ llvm::CastTo<DITypeAttr>)));
}
MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) {
@@ -338,10 +345,10 @@ MlirAttribute mlirLLVMDISubprogramAttrGet(
cast<DISubroutineTypeAttr>(unwrap(type)),
llvm::map_to_vector(
unwrapList(nRetainedNodes, retainedNodes, nodesStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); }),
+ llvm::CastTo<DINodeAttr>),
llvm::map_to_vector(
unwrapList(nAnnotations, annotations, annotationsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) {
@@ -398,7 +405,7 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet(
cast<DINodeAttr>(unwrap(entity)), cast<DIFileAttr>(unwrap(file)), line,
cast<StringAttr>(unwrap(name)),
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name,
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 5c2a65d..75c811a 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Linalg.h"
+#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
const linalg::ContractionDimensions &contractionDims = *maybeDims;
MLIRContext *ctx = linalgOp.getContext();
- auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
- return wrap(
- DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+ auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
+ return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
};
result.batch = toAttr(contractionDims.batch);
@@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
+ size_t numMaps) {
+ MlirLinalgContractionDimensions result{};
+ if (!indexingMaps || numMaps == 0)
+ return result;
+
+ SmallVector<AffineMap, 3> maps;
+ maps.reserve(numMaps);
+ for (size_t i = 0; i < numMaps; ++i) {
+ maps.push_back(unwrap(indexingMaps[i]));
+ }
+
+ FailureOr<linalg::ContractionDimensions> maybeDims =
+ linalg::inferContractionDims(maps);
+ if (failed(maybeDims))
+ return result;
+
+ MLIRContext *ctx = maps[0].getContext();
+
+ auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
+ return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
+ };
+
+ result.batch = toAttr(maybeDims->batch);
+ result.m = toAttr(maybeDims->m);
+ result.n = toAttr(maybeDims->n);
+ result.k = toAttr(maybeDims->k);
+
+ return result;
+}
+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 2dbb993..81d86ad 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -22,7 +22,7 @@ using namespace mlir;
extern "C" MlirExecutionEngine
mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths,
- bool enableObjectDump) {
+ bool enableObjectDump, bool enablePIC) {
static bool initOnce = [] {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm
@@ -38,12 +38,17 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
if (!tmBuilderOrError) {
- llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
+ llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host "
+ "because: \n";
+ consumeError(tmBuilderOrError.takeError());
return MlirExecutionEngine{nullptr};
}
+ if (enablePIC)
+ tmBuilderOrError->setRelocationModel(llvm::Reloc::PIC_);
auto tmOrError = tmBuilderOrError->createTargetMachine();
if (!tmOrError) {
- llvm::errs() << "Failed to create a TargetMachine for the host\n";
+ llvm::errs() << "Failed to create a TargetMachine for the host because: \n";
+ consumeError(tmOrError.takeError());
return MlirExecutionEngine{nullptr};
}
@@ -60,8 +65,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
jitOptions.jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(optLevel);
jitOptions.sharedLibPaths = libPaths;
jitOptions.enableObjectDump = enableObjectDump;
- auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions);
+ auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions,
+ std::move(tmOrError.get()));
if (!jitOrError) {
+ llvm::errs() << "Failed to create an ExecutionEngine because: \n";
consumeError(jitOrError.takeError());
return MlirExecutionEngine{nullptr};
}
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index f5f4ed3..e2e236a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -536,7 +536,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
if (failed(memrefType.getStridesAndOffset(strides_, *offset)))
return mlirLogicalResultFailure();
- (void)std::copy(strides_.begin(), strides_.end(), strides);
+ (void)llvm::copy(strides_, strides);
return mlirLogicalResultSuccess();
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 1881865..ffcbed8 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1129,6 +1129,11 @@ void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
blockArg.setType(unwrap(type));
}
+void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc) {
+ if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
+ blockArg.setLoc(unwrap(loc));
+}
+
MlirOperation mlirOpResultGetOwner(MlirValue value) {
return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner());
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0..7584b17 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,8 +16,10 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -42,6 +44,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
+constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -79,12 +82,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
- bool value) {
- Type llvmI1 = rewriter.getI1Type();
- return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
Location loc, MemRefDescriptor &memRefDescriptor,
@@ -509,10 +506,16 @@ struct MemoryCounterWaitOpLowering
if (std::optional<int> exp = adaptor.getExp())
ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
+ if (std::optional<int> tensor = adaptor.getTensor())
+ ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
+
rewriter.eraseOp(op);
return success();
}
+ if (adaptor.getTensor())
+ return op.emitOpError("unsupported chipset");
+
auto getVal = [](Attribute attr) -> unsigned {
if (attr)
return cast<IntegerAttr>(attr).getInt();
@@ -684,12 +687,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
- Location loc,
- const TypeConverter *typeConverter,
- bool isUnsigned, Value llvmInput,
- Value mlirInput,
- SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+ ConversionPatternRewriter &rewriter, Location loc,
+ const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+ Value mlirInput, SmallVectorImpl<Value> &operands,
+ SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
if (!vectorType) {
@@ -697,10 +699,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
Type elemType = vectorType.getElementType();
-
- if (elemType.isBF16())
- llvmInput = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -719,8 +717,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
} else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
+ attrs.push_back(
+ NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
}
int64_t numBits =
@@ -751,18 +749,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
- bool clamp, SmallVector<Value, 4> &operands) {
+ bool clamp, SmallVectorImpl<Value> &operands,
+ SmallVectorImpl<NamedAttribute> &attrs) {
Type inputType = output.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
- if (elemType.isBF16())
- output = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
- operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+ attrs.push_back(
+ NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
} else if (elemType.isInteger(32)) {
- operands.push_back(createI1Constant(rewriter, loc, clamp));
+ attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
}
}
@@ -1160,7 +1157,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
k, isRDNA3);
// Handle gfx1250.
- if (chipset == Chipset{12, 5, 0})
+ if (chipset == kGfx1250)
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);
@@ -1311,11 +1308,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
- // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
- // need to bitcast bfloats to i16 and then bitcast them back.
+ bool isGFX1250 = chipset >= kGfx1250;
+
+ // The WMMA operations represent vectors of bf16s as vectors of i16s
+ // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+ // bitcast them back.
+ auto aType = cast<VectorType>(adaptor.getSourceA().getType());
+ auto bType = cast<VectorType>(adaptor.getSourceB().getType());
+ auto destCType = cast<VectorType>(adaptor.getDestC().getType());
+ bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+ bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+ bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+ bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
VectorType rawOutType = outType;
- if (outType.getElementType().isBF16())
+ if (castOutToI16)
rawOutType = outType.clone(rewriter.getI16Type());
+ Value a = adaptor.getSourceA();
+ if (castAToI16)
+ a = LLVM::BitcastOp::create(rewriter, loc,
+ aType.clone(rewriter.getI16Type()), a);
+ Value b = adaptor.getSourceB();
+ if (castBToI16)
+ b = LLVM::BitcastOp::create(rewriter, loc,
+ bType.clone(rewriter.getI16Type()), b);
+ Value destC = adaptor.getDestC();
+ if (castDestCToI16)
+ destC = LLVM::BitcastOp::create(
+ rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -1325,18 +1344,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
return op.emitOpError("subwordOffset not supported on gfx12+");
- OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes(rawOutType);
-
SmallVector<Value, 4> operands;
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), op.getSourceA(), operands);
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), op.getSourceB(), operands);
- wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
- op.getSubwordOffset(), op.getClamp(), operands);
+ SmallVector<NamedAttribute, 4> attrs;
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+ op.getSourceA(), operands, attrs, "signA");
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+ op.getSourceB(), operands, attrs, "signB");
+ wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+ op.getSubwordOffset(), op.getClamp(), operands,
+ attrs);
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(rawOutType);
loweredOp.addOperands(operands);
+ loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);
Operation *maybeCastBack = lowered;
@@ -1492,6 +1513,20 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct ScaledExtPackedMatrixOpLowering final
+ : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
+ ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledExtPackedMatrixOp op,
+ ScaledExtPackedMatrixOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1600,6 +1635,173 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
+ int32_t firstScaleByte) {
+ // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
+ // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
+ // firstScaleByte are merged into a single attribute, scaleSel. This function
+ // performs that merging. (Note: scaleWaveHalf isn't a high-level attribute
+ // but is derived from firstScaleLane.)
+ assert(llvm::is_contained({16, 32}, blockSize));
+ assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+ const bool isFp8 = bitWidth == 8;
+ const bool isBlock16 = blockSize == 16;
+
+ if (!isFp8) {
+ int32_t bit0 = isBlock16;
+ assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
+ int32_t bit1 = (firstScaleByte == 2) << 1;
+ assert(llvm::is_contained({0, 1}, scaleWaveHalf));
+ int32_t bit2 = scaleWaveHalf << 2;
+ return bit2 | bit1 | bit0;
+ }
+
+ int32_t bit0 = isBlock16;
+ // firstScaleByte is guaranteed to fit in two bits.
+ assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+ int32_t bits2and1 = firstScaleByte << 1;
+ assert(llvm::is_contained({0, 1}, scaleWaveHalf));
+ int32_t bit3 = scaleWaveHalf << 3;
+ int32_t bits = bit3 | bits2and1 | bit0;
+ // These are invalid cases.
+ assert(!llvm::is_contained(
+ {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
+ return bits;
+}
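As a cross-check of the bit packing above, a small Python mirror of the non-fp8 (fp4/fp6) branch; illustrative only:

    def scale_sel_fp4_fp6(block_size, scale_wave_half, first_scale_byte):
        # Mirrors getScaleSel for bitWidth 4 or 6.
        bit0 = 1 if block_size == 16 else 0
        bit1 = (1 if first_scale_byte == 2 else 0) << 1
        bit2 = scale_wave_half << 2
        return bit2 | bit1 | bit0

    assert scale_sel_fp4_fp6(16, 1, 2) == 0b111
    assert scale_sel_fp4_fp6(32, 0, 1) == 0b000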
+
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
+ if (isa<fp4>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<fp8>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<bf8>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<fp6>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<bf6>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+ return std::nullopt;
+ }
+ llvm_unreachable("invalid combination of element types for packed conversion "
+ "instructions");
+}
+
+LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
+ ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
+ Location loc = op.getLoc();
+ if (chipset != kGfx1250) {
+ return rewriter.notifyMatchFailure(
+ loc,
+ "Scaled fp packed conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ }
+ // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
+ // is being selected.
+ int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
+ int32_t firstScaleByte = op.getFirstScaleByte();
+ int32_t blockSize = op.getBlockSize();
+ auto sourceType = cast<VectorType>(op.getSource().getType());
+ auto srcElemType = cast<FloatType>(sourceType.getElementType());
+ unsigned bitWidth = srcElemType.getWidth();
+
+ auto targetType = cast<VectorType>(op.getResult().getType());
+ auto destElemType = cast<FloatType>(targetType.getElementType());
+
+ IntegerType i32 = rewriter.getI32Type();
+ Value source = adaptor.getSource();
+ Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+ Type packedType = nullptr;
+ if (isa<fp4>(srcElemType)) {
+ packedType = i32;
+ packedType = getTypeConverter()->convertType(packedType);
+ } else if (isa<fp8, bf8>(srcElemType)) {
+ packedType = VectorType::get(2, i32);
+ packedType = getTypeConverter()->convertType(packedType);
+ } else if (isa<fp6, bf6>(srcElemType)) {
+ packedType = VectorType::get(3, i32);
+ packedType = getTypeConverter()->convertType(packedType);
+ } else {
+ llvm_unreachable("invalid element type for packed scaled ext");
+ }
+
+ if (!packedType || !llvmResultType) {
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+ }
+
+ std::optional<StringRef> maybeIntrinsic =
+ scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching packed scaled conversion on the given chipset");
+
+ int32_t scaleSel =
+ getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
+ Value castedScale =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
+
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes({llvmResultType});
+ loweredOp.addOperands({castedSource, castedScale});
+
+ SmallVector<NamedAttribute, 1> attrs;
+ attrs.push_back(
+ NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
+
+ loweredOp.addAttributes(attrs);
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered);
+
+ return success();
+}
+
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -2073,6 +2275,441 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
}
};
+struct AMDGPUMakeDmaBaseLowering
+ : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx1250)
+ return op->emitOpError("make_dma_base is only supported on gfx1250");
+
+ Location loc = op.getLoc();
+
+ ValueRange ldsIndices = adaptor.getLdsIndices();
+ Value lds = adaptor.getLds();
+ auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
+
+ Value ldsPtr =
+ getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+
+ ValueRange globalIndices = adaptor.getGlobalIndices();
+ Value global = adaptor.getGlobal();
+ auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
+
+ Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
+ global, globalIndices);
+
+ Type i32 = rewriter.getI32Type();
+ Type i64 = rewriter.getI64Type();
+
+ Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+ Value castForGlobalAddr =
+ LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
+
+ Value lowHalf =
+ LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+
+ Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
+ createI64Constant(rewriter, loc, 32));
+
+ Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
+
+ Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
+ Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
+
+ Value typeField = createI32Constant(rewriter, loc, 2 << 30);
+ Value highHalfPlusType =
+ LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
+
+ Value c0 = createI32Constant(rewriter, loc, 0);
+ Value c1 = createI32Constant(rewriter, loc, 1);
+ Value c2 = createI32Constant(rewriter, loc, 2);
+ Value c3 = createI32Constant(rewriter, loc, 3);
+
+ Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+ assert(v4i32 && "expected type conversion to succeed");
+ Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ castForLdsAddr, c1);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ highHalfPlusType, c3);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+struct AMDGPUMakeDmaDescriptorLowering
+ : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
+
+ Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+ Value accumulator, Value value, int64_t shift) const {
+ shift = shift % 32;
+ Value shiftAmount;
+ if (shift != 0) {
+ shiftAmount = createI32Constant(rewriter, loc, shift % 32);
+ value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
+ }
+
+ if (matchPattern(accumulator, mlir::m_Zero()))
+ return value;
+
+ return LLVM::OrOp::create(rewriter, loc, accumulator, value);
+ }
+
+ Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0) const {
+ Value mask = op.getWorkgroupMask();
+ if (!mask)
+ return sgpr0;
+
+ Type i32 = rewriter.getI32Type();
+ Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
+ return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
+ }
+
+ Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ // Compute data_size.
+ unsigned elementTypeWidthInBits = op.getElementTypeWidth();
+ assert(
+ llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
+ "expected type width to be 8, 16, 32, or 64.");
+ int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8);
+ Value size = createI32Constant(rewriter, loc, dataSize);
+ return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
+ }
+
+ Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool atomicBarrierEnable = adaptor.getAtomicBarrierAddress() != nullptr;
+ if (!atomicBarrierEnable)
+ return sgpr0;
+
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
+ }
+
+ Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool iterateEnable = adaptor.getGlobalIncrement() != nullptr;
+ if (!iterateEnable)
+ return sgpr0;
+
+ // TODO: In future PR, add other required fields for iteration.
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
+ }
+
+ Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool padEnable = op.getPadAmount() != nullptr;
+ if (!padEnable)
+ return sgpr0;
+
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
+ }
+
+ Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ if (!op.getWorkgroupMask())
+ return sgpr0;
+
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
+ }
+
+ Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool padEnable = op.getPadAmount() != nullptr;
+ if (!padEnable)
+ return sgpr0;
+
+ IntegerType i32 = rewriter.getI32Type();
+ Value padInterval = adaptor.getPadInterval();
+ // pre-condition: padInterval is a power of two between 2 and 256.
+ padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
+ padInterval, false);
+ padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
+ // post-condition: padInterval is a value between 0 and 7.
+ return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
+ }
+
+ Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool padEnable = op.getPadAmount() != nullptr;
+ if (!padEnable)
+ return sgpr0;
+
+ Value padAmount = adaptor.getPadAmount();
+ // Pre-condition: padAmount is a value between 1 and 128.
+ padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
+ // Post-condition: padAmount is a value between 0 and 127.
+ return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
+ }
+
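+ // The atomic barrier address is 8-byte aligned; it is shifted right by 3,
+ // masked to 16 bits, and packed at bit offset 32 of the descriptor.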
+ Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr1,
+ ArrayRef<Value> consts) const {
+ bool atomicBarrierEnable = adaptor.getAtomicBarrierAddress() != nullptr;
+ if (!atomicBarrierEnable)
+ return sgpr1;
+
+ Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
+ auto barrierAddressTy =
+ cast<MemRefType>(op.getAtomicBarrierAddress().getType());
+ ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
+ atomicBarrierAddress =
+ getStridedElementPtr(rewriter, loc, barrierAddressTy,
+ atomicBarrierAddress, atomicBarrierIndices);
+ IntegerType i32 = rewriter.getI32Type();
+ // Pre-condition: atomicBarrierAddress is aligned to 8 bytes, which implies
+ // that its 3 LSBs are zero.
+ atomicBarrierAddress =
+ LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
+ atomicBarrierAddress =
+ LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
+ Value mask = createI32Constant(rewriter, loc, 0xFFFF);
+ atomicBarrierAddress =
+ LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
+ return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
+ }
+
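+ // Tensor dimensions are 32-bit values that straddle two descriptor words:
+ // the low 16 bits go into one word and the high 16 bits into the next.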
+ std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr1, Value sgpr2,
+ ArrayRef<Value> consts) const {
+ SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
+ OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back();
+ Value tensorDim0;
+ if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult))
+ tensorDim0 =
+ createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tensorDim0 = cast<Value>(tensorDim0OpFoldResult);
+
+ Value c16 = createI32Constant(rewriter, loc, 16);
+ Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16);
+ sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48);
+ sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16);
+ return {sgpr1, sgpr2};
+ }
+
+ std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr2, Value sgpr3,
+ ArrayRef<Value> consts) const {
+ // TODO: Generalize to setTensorDimX.
+ SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
+ OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1);
+ Value tensorDim1;
+ if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult))
+ tensorDim1 =
+ createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tensorDim1 = cast<Value>(tensorDim1OpFoldResult);
+
+ Value c16 = createI32Constant(rewriter, loc, 16);
+ Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16);
+ sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80);
+ sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16);
+ return {sgpr2, sgpr3};
+ }
+
+ Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr, ArrayRef<Value> consts, size_t dimX,
+ int64_t offset) const {
+ SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
+
+ if (mixedSharedSizes.size() <= dimX)
+ return sgpr;
+
+ OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
+ Value tileDimX;
+ if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
+ tileDimX =
+ createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tileDimX = cast<Value>(tileDimXOpFoldResult);
+
+ return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
+ }
+
+ Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr3, ArrayRef<Value> consts) const {
+ return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
+ }
+
+ Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr4, ArrayRef<Value> consts) const {
+ return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
+ }
+
+ Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr4, ArrayRef<Value> consts) const {
+ return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
+ }
+
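+ // Global strides are masked to 48 bits and split into low and high parts
+ // that are packed into two adjacent descriptor words.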
+ std::pair<Value, Value>
+ setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgprY, Value sgprZ, ArrayRef<Value> consts,
+ size_t dimX, int64_t offset) const {
+ SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
+
+ if (mixedGlobalStrides.size() <= dimX)
+ return {sgprY, sgprZ};
+
+ OpFoldResult tensorDimXStrideOpFoldResult =
+ *(mixedGlobalStrides.rbegin() + dimX);
+ Value tensorDimXStride;
+ if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
+ tensorDimXStride =
+ createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
+
+ constexpr int64_t first48bits = (1ll << 48) - 1;
+ Value mask = createI64Constant(rewriter, loc, first48bits);
+ tensorDimXStride =
+ LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
+ IntegerType i32 = rewriter.getI32Type();
+ Value tensorDimXStrideLow =
+ LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
+
+ int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
+ Value shiftVal = createI64Constant(rewriter, loc, shift);
+ Value tensorDimXStrideHigh =
+ LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
+ tensorDimXStrideHigh =
+ LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
+
+ sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
+ sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
+ offset + shift);
+ return {sgprY, sgprZ};
+ }
+
+ std::pair<Value, Value>
+ setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
+ return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
+ 0, 160);
+ }
+
+ std::pair<Value, Value>
+ setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
+ return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
+ 1, 208);
+ }
+
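+ // Assemble descriptor group 1: populate eight 32-bit words and pack them
+ // into a <8 x i32> vector. `consts` holds the i32 constants 0..7, reused
+ // both as zero-initializers and as insertelement indices.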
+ Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<Value> consts) const {
+ Value sgprs[8];
+ for (int64_t i = 0; i < 8; i++) {
+ sgprs[i] = consts[0];
+ }
+
+ sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
+ sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
+
+ sgprs[1] =
+ setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
+ std::tie(sgprs[1], sgprs[2]) =
+ setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
+ std::tie(sgprs[2], sgprs[3]) =
+ setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
+
+ sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
+ sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
+ sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
+ std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
+ op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
+ std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
+ op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
+
+ IntegerType i32 = rewriter.getI32Type();
+ Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
+ assert(v8i32 && "expected type conversion to succeed");
+ Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
+
+ for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
+ dgroup1 =
+ LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
+ }
+
+ return dgroup1;
+ }
+
+ LogicalResult
+ matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx1250)
+ return op->emitOpError(
+ "make_dma_descriptor is only supported on gfx1250");
+
+ if (op.getRank() > 2)
+ return op->emitOpError("unimplemented: rank > 2 is not supported yet");
+
+ Location loc = op.getLoc();
+
+ IntegerType i32 = rewriter.getI32Type();
+ [[maybe_unused]] Type v4i32 =
+ this->typeConverter->convertType(VectorType::get(4, i32));
+ assert(v4i32 && "expected type conversion to succeed");
+
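+ // Materialize the i32 constants 0..7 once; they are reused as field values
+ // and as vector insertion indices below.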
+ SmallVector<Value> consts;
+ for (int64_t i = 0; i < 8; i++)
+ consts.push_back(createI32Constant(rewriter, loc, i));
+
+ Value dgroup0 = this->getDGroup0(adaptor);
+ Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
+
+ SmallVector<Value> results = {dgroup0, dgroup1};
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -2087,6 +2724,11 @@ struct ConvertAMDGPUToROCDLPass
RewritePatternSet patterns(ctx);
LLVMTypeConverter converter(ctx);
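+ // TDMBaseType descriptors lower to a <4 x i32> vector.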
+ converter.addConversion([&](TDMBaseType type) -> Type {
+ Type i32 = IntegerType::get(type.getContext(), 32);
+ return converter.convertType(VectorType::get(4, i32));
+ });
+
populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
LLVMConversionTarget target(getContext());
target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
@@ -2122,25 +2764,27 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
populateAMDGPUMemorySpaceAttributeConversions(converter);
- patterns
- .add<FatRawBufferCastLowering,
- RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
- RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
- RawBufferOpLowering<RawBufferAtomicFaddOp,
- ROCDL::RawPtrBufferAtomicFaddOp>,
- RawBufferOpLowering<RawBufferAtomicFmaxOp,
- ROCDL::RawPtrBufferAtomicFmaxOp>,
- RawBufferOpLowering<RawBufferAtomicSmaxOp,
- ROCDL::RawPtrBufferAtomicSmaxOp>,
- RawBufferOpLowering<RawBufferAtomicUminOp,
- ROCDL::RawPtrBufferAtomicUminOp>,
- RawBufferOpLowering<RawBufferAtomicCmpswapOp,
- ROCDL::RawPtrBufferAtomicCmpSwap>,
- AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
- SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
+ patterns.add<
+ FatRawBufferCastLowering,
+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+ RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
+ RawBufferOpLowering<RawBufferAtomicFaddOp,
+ ROCDL::RawPtrBufferAtomicFaddOp>,
+ RawBufferOpLowering<RawBufferAtomicFmaxOp,
+ ROCDL::RawPtrBufferAtomicFmaxOp>,
+ RawBufferOpLowering<RawBufferAtomicSmaxOp,
+ ROCDL::RawPtrBufferAtomicSmaxOp>,
+ RawBufferOpLowering<RawBufferAtomicUminOp,
+ ROCDL::RawPtrBufferAtomicUminOp>,
+ RawBufferOpLowering<RawBufferAtomicCmpswapOp,
+ ROCDL::RawPtrBufferAtomicCmpSwap>,
+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
+ chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000..79816fc
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,665 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+ StringRef name, FunctionType funcT, bool setPrivate,
+ SymbolTableCollection *symbolTables = nullptr) {
+ OpBuilder::InsertionGuard g(b);
+ assert(!symTable->getRegion(0).empty() && "expected non-empty region");
+ b.setInsertionPointToStart(&symTable->getRegion(0).front());
+ FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
+ if (setPrivate)
+ funcOp.setPrivate();
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
+ symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
+ }
+ return funcOp;
+}
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. The declared function returns an
+/// i64, unless a different result type is specified.
+static FailureOr<FuncOp>
+lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
+ StringRef name, TypeRange paramTypes,
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {}) {
+ if (!resultType)
+ resultType = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
+ FailureOr<FuncOp> func =
+ lookupFnDecl(symTable, funcName, funcT, symbolTables);
+ // Failed due to type mismatch.
+ if (failed(func))
+ return func;
+ // Successfully matched existing decl.
+ if (*func)
+ return *func;
+
+ return createFnDecl(b, symTable, funcName, funcT,
+ /*setPrivate=*/true, symbolTables);
+}
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function for a binary arithmetic operation.
+///
+/// Parameter 1: APFloat semantics
+/// Parameter 2: Left-hand side operand
+/// Parameter 3: Right-hand side operand
+///
+/// This function will return a failure if the function is found but has an
+/// unexpected signature.
+///
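+/// For example, for name "add" the declaration is equivalent to:
+///   func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64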
+static FailureOr<FuncOp>
+lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ SymbolTableCollection *symbolTables = nullptr) {
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
+ symbolTables);
+}
+
+static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
+ int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ return arith::ConstantOp::create(b, loc, b.getI32Type(),
+ b.getIntegerAttr(b.getI32Type(), sem));
+}
+
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the result into a vector. If the given operands and result type are
+/// not vectors, call the function directly. The second operand is optional.
+template <typename Fn>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+ Value operand1, Value operand2, Type resultType,
+ Fn fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ ResultRange scalars1 =
+ vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Create a vector of empty values.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ SmallVector<Value> results;
+ for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2)) {
+ Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+ results.push_back(result);
+ }
+
+ // Package the results into a vector.
+ return vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+ for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+ Type type = value.getType();
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ type = vecTy.getElementType();
+ }
+ if (!type.isIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ op, "only integers and floats (or vectors thereof) are supported");
+ }
+ if (type.getIntOrFloatBitWidth() > 64)
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ return success();
+}
+
+/// Rewrite a binary arithmetic operation to an APFloat function call.
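+///
+/// Operands are bitcast to same-width integers and zero-extended to i64; the
+/// runtime function receives an i32 encoding of the APFloat semantics plus
+/// both operand bit patterns, and its i64 result is truncated back to the
+/// original width and bitcast to the float type.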
+template <typename OpTy>
+struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ BinaryArithOpToAPFloatConversion(MLIRContext *context,
+ const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ FailureOr<FuncOp> fn =
+ lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(resultType);
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ const char *APFloatName;
+};
+
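+/// Rewrite float-to-float casts (e.g. arith.extf, arith.truncf) to calls to
+/// the APFloat "convert" runtime function.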
+template <typename OpTy>
+struct FpToFpConversion final : OpRewritePattern<OpTy> {
+ FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
+ rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the width of the result float type.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
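+/// Rewrite float-to-integer casts (e.g. arith.fptosi, arith.fptoui) to calls
+/// to the APFloat "convert_to_int" runtime function.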
+template <typename OpTy>
+struct FpToIntConversion final : OpRewritePattern<OpTy> {
+ FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned) {}
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outIntTy = cast<IntegerType>(resultType);
+ Value outWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {inSemValue, outWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the width of the result integer type.
+ return arith::TruncIOp::create(rewriter, loc, outIntTy,
+ resultOp->getResult(0));
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
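+/// Rewrite integer-to-float casts (e.g. arith.sitofp, arith.uitofp) to calls
+/// to the APFloat "convert_from_int" runtime function.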
+template <typename OpTy>
+struct IntToFpConversion final : OpRewritePattern<OpTy> {
+ IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned) {}
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inIntTy = cast<IntegerType>(operand1.getType());
+ Value operandBits = operand1;
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
+ }
+ }
+
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the width of the result float type.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
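+/// Rewrite arith.cmpf to a call to the APFloat "compare" runtime function,
+/// which returns the llvm::APFloat::cmpResult as an i8; the requested
+/// predicate is then evaluated by comparing against the possible cmpResult
+/// values.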
+struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
+ CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(arith::CmpFOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i8Type = IntegerType::get(symTable->getContext(), 8);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(lhs.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ Value comparisonResult =
+ func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Generate an i1 SSA value that is "true" if the comparison result
+ // matches the given `val`.
+ auto checkResult = [&](llvm::APFloat::cmpResult val) {
+ return arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
+ arith::ConstantOp::create(
+ rewriter, loc, i8Type,
+ rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
+ .getResult());
+ };
+ // Generate an i1 SSA value that is "true" if the comparison result
+ // matches any of the given `vals`.
+ std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)>
+ checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
+ Value first = checkResult(vals.front());
+ if (vals.size() == 1)
+ return first;
+ Value rest = checkResults(vals.drop_front());
+ return arith::OrIOp::create(rewriter, loc, first, rest)
+ .getResult();
+ };
+
+ // This switch-case statement was taken from arith::applyCmpPredicate.
+ Value result;
+ switch (op.getPredicate()) {
+ case arith::CmpFPredicate::AlwaysFalse:
+ result =
+ arith::ConstantOp::create(rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, 0))
+ .getResult();
+ break;
+ case arith::CmpFPredicate::OEQ:
+ result = checkResult(llvm::APFloat::cmpEqual);
+ break;
+ case arith::CmpFPredicate::OGT:
+ result = checkResult(llvm::APFloat::cmpGreaterThan);
+ break;
+ case arith::CmpFPredicate::OGE:
+ result = checkResults(
+ {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::OLT:
+ result = checkResult(llvm::APFloat::cmpLessThan);
+ break;
+ case arith::CmpFPredicate::OLE:
+ result = checkResults(
+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::ONE:
+ // Not cmpUnordered and not cmpEqual.
+ result = checkResults(
+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+ break;
+ case arith::CmpFPredicate::ORD:
+ // Not cmpUnordered.
+ result = checkResults({llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UEQ:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UGT:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+ break;
+ case arith::CmpFPredicate::UGE:
+ result = checkResults({llvm::APFloat::cmpUnordered,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::ULT:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+ break;
+ case arith::CmpFPredicate::ULE:
+ result = checkResults({llvm::APFloat::cmpUnordered,
+ llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UNE:
+ // Not cmpEqual.
+ result = checkResults({llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpUnordered});
+ break;
+ case arith::CmpFPredicate::UNO:
+ result = checkResult(llvm::APFloat::cmpUnordered);
+ break;
+ case arith::CmpFPredicate::AlwaysTrue:
+ result =
+ arith::ConstantOp::create(rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, 1))
+ .getResult();
+ break;
+ }
+ return result;
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
+ NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(arith::NegFOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(operand1.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand1));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+namespace {
+struct ArithToAPFloatConversionPass final
+ : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+ using Base::Base;
+
+ void runOnOperation() override;
+};
+
+void ArithToAPFloatConversionPass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
+ getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
+ context, "subtract", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
+ context, "multiply", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
+ context, "divide", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
+ context, "remainder", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
+ context, "minnum", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
+ context, "maxnum", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
+ context, "minimum", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
+ context, "maximum", getOperation());
+ patterns
+ .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
+ CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
+ context, getOperation());
+ patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
+ patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
+ LogicalResult result = success();
+ ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
+ if (diag.getSeverity() == DiagnosticSeverity::Error) {
+ result = failure();
+ }
+ // NB: if you don't return failure, no other diag handlers will fire (see
+ // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+ return failure();
+ });
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ if (failed(result))
+ return signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000..31fce7a
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ MLIRFuncUtils
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b609990..220826d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -280,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
adaptor.getOperands(), op->getAttrs(),
+ /*propAttr=*/Attribute{},
*getTypeConverter(), rewriter);
}
@@ -481,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+ op.getLhs().getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8..613dc6d 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 86d02e6..6a0c211 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
- op->getAttrs(), *getTypeConverter(), rewriter);
+ op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
+ rewriter);
}
};
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 93fe2ed..2220f61 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -374,9 +374,12 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
// Create a memory effect attribute corresponding to readnone.
if (funcOp->hasAttr(readnoneAttrName)) {
auto memoryAttr = LLVM::MemoryEffectsAttr::get(
- rewriter.getContext(),
- {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
- LLVM::ModRefInfo::NoModRef});
+ rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::NoModRef,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef});
newFuncOp.setMemoryEffectsAttr(memoryAttr);
}
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 425594b..f143a9e 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -66,7 +66,10 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/noModRef,
- /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
+ /*errnoMem=*/noModRef,
+ /*targetMem0=*/noModRef,
+ /*targetMem1=*/noModRef);
func.setMemoryEffectsAttr(memAttr);
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d64c4d6..5848489 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -419,7 +419,10 @@ struct LowerGpuOpsToNVVMOpsPass final
if (this->hasRedux)
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
configureGpuToNVVMConversionLegality(target);
- if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(
+ applyPartialConversion(m, target, std::move(llvmPatterns), config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 99c059c..6254de8 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
using namespace mlir;
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF32())
return type.getOperand() == "COp" ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;
-
+ if (type.getElementType().isF64())
+ return NVVM::MMATypes::f64;
if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
// then passed on to the intrinsic call. Emit llvm ops to extract individual
// values from lowered memrefs.
SmallVector<Value> unpackedOps;
-
auto unpackOp = [&](Value operand) {
+ // f64 a and b fragments are not structs but scalars.
+ if (!isa<LLVM::LLVMStructType>(operand.getType())) {
+ unpackedOps.push_back(operand);
+ return;
+ }
+ // Every other type is lowered to an LLVM struct; extract its values.
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
return failure();
Location loc = subgroupMmaConstantOp.getLoc();
Value cst = adaptor.getOperands()[0];
- LLVM::LLVMStructType type = convertMMAToLLVMType(
+ Type type = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
+ // If the element is not a struct, it means it's a scalar f64.
+ auto structType = dyn_cast<LLVM::LLVMStructType>(type);
+ if (!structType) {
+ rewriter.replaceOp(subgroupMmaConstantOp, cst);
+ return success();
+ }
// If the element type is a vector create a vector from the operand.
- if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
+ if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
}
cst = vecCst;
}
- Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
- for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
+ for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
matrixStruct =
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
return failure();
Location loc = subgroupMmaElementwiseOp.getLoc();
size_t numOperands = adaptor.getOperands().size();
- LLVM::LLVMStructType destType = convertMMAToLLVMType(
+ Type destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
- Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
- for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+
+ // If the element is not a struct, it means it's a scalar f64.
+ LLVM::LLVMStructType structDestTy =
+ dyn_cast<LLVM::LLVMStructType>(destType);
+ if (!structDestTy) {
+ SmallVector<Value> operands;
+ for (auto operand : adaptor.getOperands()) {
+ operands.push_back(operand);
+ }
+ Value element = createScalarOp(
+ rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
+ rewriter.replaceOp(subgroupMmaElementwiseOp, element);
+ return success();
+ }
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
+ for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
} // namespace
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
auto nRow = type.getShape()[0];
auto nCol = type.getShape()[1];
std::pair<Type, unsigned> typeInfo =
NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
+ // Special handling for f64 a and b fragments.
+ Type f64Ty = Float64Type::get(type.getContext());
+ if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+ return f64Ty;
+ }
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
index bc2f2f2..d4b4c46 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -107,16 +107,16 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
- Type n_type = n.getType();
+ Type nType = n.getType();
Value m = adaptor.getRhs();
// Define the constants
- Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 0));
- Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 1));
- Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, -1));
// Compute `x`.
Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
@@ -157,14 +157,14 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
- Type n_type = n.getType();
+ Type nType = n.getType();
Value m = adaptor.getRhs();
// Define the constants
- Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 0));
- Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 0));
+ Value one = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 1));
// Compute the non-zero result.
Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
@@ -193,16 +193,16 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
- Type n_type = n.getType();
+ Type nType = n.getType();
Value m = adaptor.getRhs();
// Define the constants
- Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 0));
- Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 1));
- Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, -1));
// Compute `x`.
Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 48a0319..f28a6cc 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods
//===----------------------------------------------------------------------===//
-void LLVM::detail::setNativeProperties(Operation *op,
- IntegerOverflowFlags overflowFlags) {
- if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
- iface.setOverflowFlags(overflowFlags);
-}
-
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags) {
+ ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
SmallVector<Type> resultTypes;
@@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite(
}
// Create the operation through state since we don't know its C++ type.
- Operation *newOp =
- rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
- resultTypes, targetAttrs);
-
- setNativeProperties(newOp, overflowFlags);
+ OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
+ resultTypes, targetAttrs);
+ state.propertiesAttr = propertiesAttr;
+ Operation *newOp = rewriter.create(state);
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b5..e5969c2 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags) {
+ ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
@@ -116,18 +116,38 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
- return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
- rewriter, overflowFlags);
-
- auto callback = [op, targetOp, targetAttrs, overflowFlags,
+ return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr,
+ typeConverter, rewriter);
+ auto callback = [op, targetOp, targetAttrs, propertiesAttr,
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
- Operation *newOp =
- rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
- operands, llvm1DVectorTy, targetAttrs);
- LLVM::detail::setNativeProperties(newOp, overflowFlags);
+ OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp),
+ operands, llvm1DVectorTy, targetAttrs);
+ state.propertiesAttr = propertiesAttr;
+ Operation *newOp = rewriter.create(state);
return newOp->getResult(0);
};
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+}
+
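+/// Return "true" if the given type is a floating point type (or a vector
+/// thereof) that the given type converter does not convert to a floating
+/// point type.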
+bool LLVM::detail::isUnsupportedFloatingPointType(
+ const TypeConverter &typeConverter, Type type) {
+ FloatType floatType = getFloatingPointType(type);
+ if (!floatType)
+ return false;
+ Type convertedType = typeConverter.convertType(floatType);
+ if (!convertedType)
+ return true;
+ return !isa<FloatType>(convertedType);
+}
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 16ef11a..59a16df 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -93,13 +93,13 @@ public:
/// Different MPI implementations have different communicator types.
/// Using i64 as a portable, intermediate type.
/// Appropriate cast needs to take place before calling MPI functions.
- virtual Value getCommWorld(const Location loc,
+ virtual Value getCommWorld(Location loc,
ConversionPatternRewriter &rewriter) = 0;
/// Type converter provides i64 type for communicator type.
/// Converts to native type, which might be ptr or int or whatever.
- virtual Value castComm(const Location loc,
- ConversionPatternRewriter &rewriter, Value comm) = 0;
+ virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
+ Value comm) = 0;
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
virtual intptr_t getStatusIgnore() = 0;
@@ -109,13 +109,12 @@ public:
/// Gets or creates an MPI datatype as a value which corresponds to the given
/// type.
- virtual Value getDataType(const Location loc,
- ConversionPatternRewriter &rewriter, Type type) = 0;
+ virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
+ Type type) = 0;
/// Gets or creates an MPI_Op value which corresponds to the given
/// enum value.
- virtual Value getMPIOp(const Location loc,
- ConversionPatternRewriter &rewriter,
+ virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_ReductionOpEnum opAttr) = 0;
};
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 11f866c..0a382d8 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -122,7 +122,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
return totalSizeBytes.getResult();
}
-static emitc::ApplyOp
+static emitc::AddressOfOp
createPointerFromEmitcArray(Location loc, OpBuilder &builder,
TypedValue<emitc::ArrayType> arrayValue) {
@@ -133,9 +133,9 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
- emitc::ApplyOp ptr = emitc::ApplyOp::create(
+ emitc::AddressOfOp ptr = emitc::AddressOfOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
- builder.getStringAttr("&"), subPtr);
+ subPtr);
return ptr;
}
@@ -225,12 +225,12 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
- emitc::ApplyOp srcPtr =
+ emitc::AddressOfOp srcPtr =
createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
- emitc::ApplyOp targetPtr =
+ emitc::AddressOfOp targetPtr =
createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
@@ -319,8 +319,8 @@ struct ConvertGetGlobal final
emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
- rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
- op, pointerType, rewriter.getStringAttr("&"), globalLValue);
+ rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
+ globalLValue);
return success();
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
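At the C level, the intent of these lowerings is unchanged: take the address of the first array element (or of a global lvalue) directly instead of spelling `&` as an opaque apply. A hand-written sketch of the code shape the copy lowering ultimately targets (illustrative only, not generated output):

#include <cstring>

void copySketch() {
  float src[4][8] = {};
  float dst[4][8] = {};
  // emitc.subscript followed by emitc.address_of corresponds to taking the
  // address of the first element, i.e. &src[0][0], instead of an opaque
  // emitc.apply "&".
  float *srcPtr = &src[0][0];
  float *dstPtr = &dst[0][0];
  std::memcpy(dstPtr, srcPtr, sizeof(src));
}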
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9348d3c1..64a7f56 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -922,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount, adaptor.getPredicate());
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
- op, barrier, txcount, adaptor.getPredicate());
+ op, Type{}, // return-value is optional and is void by default
+ barrier, txcount, // barrier and txcount
+ NVVM::MemScopeKind::CTA, // default scope is CTA
+ false, // relaxed-semantics is false
+ adaptor.getPredicate());
return success();
}
};
@@ -949,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 021e31a..7fdc23a 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -66,6 +66,9 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
for (NamedAttribute attr : op->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
Type convertedType = converter->convertType(typeAttr.getValue());
+ if (!convertedType)
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert type in attribute");
convertedAttrs.emplace_back(attr.getName(),
TypeAttr::get(convertedType));
} else {
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 37cfc9f..03842cc 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -36,6 +36,7 @@ namespace {
struct SCFToControlFlowPass
: public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -736,7 +737,9 @@ void SCFToControlFlowPass::runOnOperation() {
target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- if (failed(
- applyPartialConversion(getOperation(), target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
+ config)))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 76a822b..309121f 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -453,10 +453,24 @@ static LogicalResult processParallelLoop(
1, 2,
rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
rewriter.getAffineSymbolExpr(1));
+ // Map through cloningMap first so we use values valid at the launch
+ // scope, then ensure they are launch-independent (or cloned constants).
+ Value mappedStep = cloningMap.lookupOrDefault(step);
+ Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound);
+
+ mappedStep = ensureLaunchIndependent(mappedStep);
+ mappedLowerBound = ensureLaunchIndependent(mappedLowerBound);
+
+ // If either cannot be made available above the launch, fail gracefully.
+ if (!mappedStep || !mappedLowerBound) {
+ return rewriter.notifyMatchFailure(
+ parallelOp, "lower bound / step must be constant or defined above "
+ "the gpu.launch");
+ }
+
newIndex = AffineApplyOp::create(
rewriter, loc, annotation.getMap().compose(lowerAndStep),
- ValueRange{operand, ensureLaunchIndependent(step),
- ensureLaunchIndependent(lowerBound)});
+ ValueRange{operand, mappedStep, mappedLowerBound});
// If there was also a bound, insert that, too.
// TODO: Check that we do not assign bounds twice.
if (annotation.getBound()) {
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 460595b..6423d49 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -188,7 +188,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
OpBuilder::InsertionGuard guard(builder);
Type type = reduce.getOperands()[reductionIndex].getType();
auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
- "__scf_reduction", type);
+ "__scf_reduction", type,
+ /*byref_element_type=*/{});
symbolTable.insert(decl);
builder.createBlock(&decl.getInitializerRegion(),
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 50fca56..02b61bd 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1520,20 +1520,12 @@ public:
if (!dstType)
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
- Location loc = tanOp.getLoc();
- Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
- Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
- rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
+ rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
+ adaptor.getOperands());
return success();
}
};
-/// Convert `spirv.Tanh` to
-///
-/// exp(2x) - 1
-/// -----------
-/// exp(2x) + 1
-///
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
public:
using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
@@ -1546,18 +1538,8 @@ public:
if (!dstType)
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
- Location loc = tanhOp.getLoc();
- Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
- Value multiplied =
- LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
- Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
- Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
- Value numerator =
- LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
- Value denominator =
- LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
- rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
- denominator);
+ rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
+ adaptor.getOperands());
return success();
}
};
diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
index 9921a06..feb0489 100644
--- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
+++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
@@ -23,8 +23,11 @@ namespace mlir {
using namespace mlir;
-namespace {
+//===----------------------------------------------------------------------===//
+// PoisonOpLowering
+//===----------------------------------------------------------------------===//
+namespace {
struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
-
} // namespace
-//===----------------------------------------------------------------------===//
-// PoisonOpLowering
-//===----------------------------------------------------------------------===//
-
LogicalResult
PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -61,6 +59,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
}
//===----------------------------------------------------------------------===//
+// UnreachableOpLowering
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct UnreachableOpLowering
+ : public ConvertOpToLLVMPattern<ub::UnreachableOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+UnreachableOpLowering::matchAndRewrite(
+ ub::UnreachableOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -93,7 +114,7 @@ struct UBToLLVMConversionPass
void mlir::ub::populateUBToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<PoisonOpLowering>(converter);
+ patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter);
}
//===----------------------------------------------------------------------===//
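A minimal sketch of wiring the extended pattern set into a stand-alone conversion driver, assuming the usual LLVM lowering boilerplate (the driver function itself is illustrative and not part of this patch):

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Registers both PoisonOpLowering and the new UnreachableOpLowering and runs
// a partial conversion over `module`.
LogicalResult lowerUBToLLVM(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  ub::populateUBToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(*ctx);
  target.addIllegalDialect<ub::UBDialect>();
  return applyPartialConversion(module.getOperation(), target,
                                std::move(patterns));
}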
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index 244d214..3831387 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
}
};
+struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> {
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(ub::UnreachableOp op, OpAdaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final
void mlir::ub::populateUBToSPIRVConversionPatterns(
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<PoisonOpLowering>(converter, patterns.getContext());
+ patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ec..05d541f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -345,7 +345,8 @@ public:
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
- MemRefType memRefType = scatter.getMemRefType();
+ auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
+ assert(memRefType && "The base should be bufferized");
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
@@ -1654,6 +1655,20 @@ private:
return failure();
}
}
+ } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+ // Print other floating-point types using the APFloat runtime library.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ Value floatBits =
+ LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
+ printer =
+ LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+ emitCall(rewriter, loc, printer.value(),
+ ValueRange({semValue, floatBits}));
+ return success();
} else {
return failure();
}
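The (semantics enum, zero-extended bits) pair passed to the print helper above can be reproduced with plain APFloat calls; a small host-side sketch of the same encoding (assumes the type is at most 64 bits wide):

#include "llvm/ADT/APFloat.h"
#include <cstdint>
#include <utility>

// Encode a float the way the lowering above hands it to the runtime print
// helper: an i32 naming the APFloat semantics plus the value's bit pattern
// zero-extended to 64 bits.
std::pair<int32_t, uint64_t> encodeForPrinting(const llvm::APFloat &value) {
  int32_t sem = llvm::APFloatBase::SemanticsToEnum(value.getSemantics());
  uint64_t bits = value.bitcastToAPInt().getZExtValue();
  return {sem, bits};
}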
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 1b4d1a4..079e1e2 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -519,8 +519,13 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
return lowerToScatteredLoadOp(readOp, rewriter);
}
- // Perform common data transfer checks.
VectorType vecTy = readOp.getVectorType();
+
+ // Lower using load.gather in 1D case
+ if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
+ return lowerToScatteredLoadOp(readOp, rewriter);
+
+ // Perform common data transfer checks.
if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
return failure();
@@ -562,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -616,7 +622,8 @@ struct TransferWriteLowering
auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
ndDesc, indices,
/*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -720,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
/*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -758,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
auto storeNdOp =
xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
/*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(storeOp, storeNdOp);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index de552ce..0ecb50e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
- BasePtr = 0, // Base pointer (i64)
- BaseShapeW = 2, // Base shape width (i32)
- BaseShapeH = 3, // Base shape height (i32)
- TensorOffsetW = 4, // Tensor offset W (i32)
- TensorOffsetH = 5 // Tensor offset H (i32)
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ BasePitch = 4, // Base pitch (i32)
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
@@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
- Value offsetW;
- Value offsetH;
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
- if (rank != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern
if (!sourceMemrefTy.hasRank()) {
return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
}
- baseAddr =
- memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ // Access adaptor after failure check to avoid rolling back generated code
+ // for materialization cast.
+ baseAddr = adaptor.getSource();
} else {
baseAddr = adaptor.getSource();
+ if (baseAddr.getType() != i64Ty) {
+ // Pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ }
+ }
+ // 1D tensor descriptor is just the base address.
+ if (rank == 1) {
+ rewriter.replaceOp(op, baseAddr);
+ return success();
}
// Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -209,19 +222,11 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets are not supported (0 is used).
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
- if (sourceMemrefTy) {
- // Cast index to i64.
- baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- } else if (baseAddr.getType() != i64Ty) {
- // Pointer type may be i32. Cast to i64 if needed.
- baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
- }
+ // Get pitch value from op fold results.
+ Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetW, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetW));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetH, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload =
+ vector::InsertOp::create(rewriter, loc, basePitch, payload,
+ static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
@@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto tileRank = tdescTy.getRank();
+ if (opOffsetsSize != tileRank)
+ return rewriter.notifyMatchFailure(
+ op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
- if (elemBitSize % 8 != 0)
+ bool isSubByte = elemBitSize < 8;
+ uint64_t wScaleFactor = 1;
+
+ if (!isSubByte && (elemBitSize % 8 != 0))
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
+ auto tileW = tdescTy.getDimSize(tileRank - 1);
+ // For sub-byte types, only 4 bits are currently supported.
+ if (isSubByte) {
+ if (elemBitSize != 4)
+ return rewriter.notifyMatchFailure(
+ op, "Only sub byte types of 4bits are supported.");
+ if (tileRank != 2)
+ return rewriter.notifyMatchFailure(
+ op, "Sub byte types are only supported for 2D tensor descriptors.");
+ auto subByteFactor = 8 / elemBitSize;
+ auto tileH = tdescTy.getDimSize(0);
+ // Handle special case for packed load.
+ if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ if (op.getPacked().value_or(false)) {
+ // Packed load is implemented as packed loads of 8-bit elements.
+ if (tileH == systolicDepth * 4 &&
+ tileW == executionSize * subByteFactor) {
+ // Usage case for loading as Matrix B with pack request.
+ // The source is assumed to be pre-packed into 8-bit elements.
+ // Emulate with 8-bit loads with a pack request.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(8);
+ tileW = executionSize;
+ wScaleFactor = subByteFactor;
+ }
+ }
+ }
+ // If not handled by packed load case above, handle other cases.
+ if (wScaleFactor == 1) {
+ auto sub16BitFactor = subByteFactor * 2;
+ if (tileW == executionSize * sub16BitFactor) {
+ // Usage case for loading as the Matrix A operand.
+ // Emulate with 16-bit loads/stores.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(16);
+ tileW = executionSize;
+ wScaleFactor = sub16BitFactor;
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported tile shape for sub byte types.");
+ }
+ }
+ // Recompute the element bit size for emulation.
+ elemBitSize = elemType.getIntOrFloatBitWidth();
+ }
- VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
- Value payLoadAsI64 =
- vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
- Value basePtr = vector::ExtractOp::create(
- rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
- Value baseShapeW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
- Value baseShapeH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
- // Offsets are provided by the op.
- // convert them to i32.
- Value offsetW =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
- offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetW);
- Value offsetH =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
- offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- // Convert base pointer (i64) to LLVM pointer type.
- Value basePtrLLVM =
- LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
- // Compute element byte size and surface width in bytes.
- Value elemByteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
- Value surfaceW =
- arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
- // Get tile sizes and vblocks from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(1);
- auto tileH = tdescTy.getDimSize(0);
- int32_t vblocks = tdescTy.getArrayLength();
- if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- Value src = adaptor.getValue();
- // If store value is a scalar, get value from op instead of adaptor.
- // Adaptor might have optimized away single element vector
- if (src.getType().isIntOrFloat()) {
- src = op.getValue();
+ if (tileRank == 2) {
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr =
+ vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+ static_cast<int>(NdTdescOffset::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ Value basePitch = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
+ // Offsets are provided by the op.
+ // Convert them to i32.
+ Value offsetW =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ Value offsetH =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // FIXME: The width (pitch) is not the same as baseShapeW; it should be
+ // the stride of the second-to-last dimension in row-major layout.
+ // Compute width in bytes.
+ Value baseShapeWInBytes =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+ // Compute pitch in bytes.
+ Value basePitchBytes =
+ arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
+
+ if (wScaleFactor > 1) {
+ // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
+ // Note: tileW is already scaled above.
+ Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+ baseShapeWInBytes = arith::ShRSIOp::create(
+ rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+ basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+ wScaleFactorValLog2);
+ offsetW =
+ arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
}
- VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
- if (!srcVecTy)
- return rewriter.notifyMatchFailure(
- op, "Expected store value to be a vector type.");
- // Get flat vector type of integer type with matching element bit size.
- VectorType newSrcVecTy =
- encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
- if (srcVecTy != newSrcVecTy)
- src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
- xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
- rewriter.eraseOp(op);
- } else {
- auto loadCacheControl =
- translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
- xevm::BlockPrefetch2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ // Get tile height from the tensor descriptor type.
+ auto tileH = tdescTy.getDimSize(0);
+ // Get vblocks from the tensor descriptor type.
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // The adaptor might have optimized away a single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
- VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
- const bool vnni = op.getPacked().value_or(false);
- auto transposeValue = op.getTranspose();
- bool transpose =
- transposeValue.has_value() && transposeValue.value()[0] == 1;
- VectorType loadedTy = encodeVectorTypeTo(
- dstVecTy, vnni ? rewriter.getI32Type()
- : rewriter.getIntegerType(elemBitSize));
-
- Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+ baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+ tileH, vblocks, transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ } else {
+ // 1D tensor descriptor.
+ // `tdesc` represents the base address as i64.
+ // The offset is in elements; multiply it by the element byte size.
+ // Compute byte offset.
+ // byteOffset = offset * elementByteSize
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI64Type(), offset);
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+ Value byteOffset =
+ rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+ // Final address = basePtr + byteOffset
+ Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+ loc, tdesc,
+ getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+ byteOffset));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value finalPtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // The adaptor might have optimized away a single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+ op, finalPtrLLVM, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ VectorType resTy = cast<VectorType>(op.getValue().getType());
+ VectorType loadedTy =
+ encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+ Value load = xevm::BlockLoadOp::create(
+ rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
- resultFlatVec = vector::BitCastOp::create(
- rewriter, loc,
- encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
- resultFlatVec);
- rewriter.replaceOp(op, resultFlatVec);
+ if (loadedTy != resTy)
+ load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+ rewriter.replaceOp(op, load);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+ "descriptor rank == 1");
}
}
return success();
@@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
-// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
-// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
-// 32 bits will be converted to 32 bits.
class CreateMemDescOpPattern final
: public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
@@ -522,16 +653,7 @@ public:
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resTy = op.getMemDesc();
-
- // Create the result MemRefType with the same shape, element type, and
- // memory space
- auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
- Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
- auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
- op.getSource(), zero, ValueRange());
- rewriter.replaceOp(op, viewOp);
+ rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
@@ -551,19 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
- Value basePtrStruct = adaptor.getMemDesc();
+ Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
- Value data;
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
- data = op.getResult();
- else
- data = adaptor.getData();
- VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ Type dataTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Type resType = op.getResult().getType();
+ // Some transforms may leave a unit dimension in the 2D vector; adaptors
+ // do not catch this for results.
+ if (auto vecType = dyn_cast<VectorType>(resType)) {
+ assert(llvm::count_if(vecType.getShape(),
+ [](int64_t d) { return d != 1; }) <= 1 &&
+ "Expected either 1D vector or nD with unit dimensions");
+ resType = VectorType::get({vecType.getNumElements()},
+ vecType.getElementType());
+ }
+ dataTy = resType;
+ } else
+ dataTy = adaptor.getData().getType();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
if (!valOrResVecTy)
- valOrResVecTy = VectorType::get(1, data.getType());
- if (valOrResVecTy.getShape().size() != 1)
- return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
+ valOrResVecTy = VectorType::get(1, dataTy);
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -579,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
- rewriter, loc, basePtrStruct);
-
- // Convert base pointer (ptr) to i32
- Value basePtrI32 = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
- basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
- elemByteSize);
+ Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+ linearOffset, elemByteSize);
// convert base pointer (i32) to LLVM pointer type
- basePtrLLVM =
+ Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
if (op.getSubgroupBlockIoAttr()) {
@@ -929,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ // Scattered descriptors are not supported in XeVM lowering.
if (type.isScattered())
+ return {};
+ if (type.getRank() == 1)
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
- // Convert MemDescType into flattened MemRefType for SLM
+ // Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
- Type elemTy = type.getElementType();
- int numElems = type.getNumElements();
- return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ return IntegerType::get(&getContext(), 32);
});
typeConverter.addConversion([&](MemRefType type) -> Type {
- // Convert MemRefType to i64 type.
+ if (type.getMemorySpaceAsInt() == 3)
+ return IntegerType::get(&getContext(), 32);
return IntegerType::get(&getContext(), 64);
});
@@ -1059,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass
};
typeConverter.addSourceMaterialization(
singleElementVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
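A worked numeric sketch of the sub-byte scaling performed by the emulation above, for 4-bit elements and executionSize = 16 (plain host arithmetic; the concrete tile and offset values are only examples):

#include "llvm/Support/MathExtras.h"
#include <cassert>

void subByteScalingExample() {
  unsigned elemBits = 4;
  unsigned subByteFactor = 8 / elemBits;        // 2: i4 elements per byte.
  unsigned sub16BitFactor = subByteFactor * 2;  // 4: i4 elements per i16.
  unsigned executionSize = 16;

  // A 16 x 64 tile of i4 (Matrix A style) is emulated as a 16 x 16 tile of
  // i16; W-dimension quantities shrink by sub16BitFactor.
  unsigned tileW = executionSize * sub16BitFactor; // 64 i4 elements.
  unsigned scaledTileW = executionSize;            // 16 i16 elements.

  // Offsets and widths in the W dimension are rescaled with a right shift,
  // as in the lowering above.
  unsigned offsetW = 32;
  unsigned scaledOffsetW = offsetW >> llvm::Log2_64(sub16BitFactor); // 8.

  assert(tileW / scaledTileW == sub16BitFactor);
  assert(scaledOffsetW * sub16BitFactor == offsetW);
  (void)scaledOffsetW;
}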
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f276984..20a420d 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -290,7 +290,7 @@ static LLVM::CallOp createDeviceFunctionCall(
ArrayRef<Type> argTypes, ArrayRef<Value> args,
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
- auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+ auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
Location loc = op->getLoc();
@@ -401,7 +401,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::NoModRef,
- /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
Value result =
@@ -450,7 +453,10 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
- /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
funcAttr.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
@@ -556,7 +562,10 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
- /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
funcAttr = noUnwindAttrs;
funcAttr.memEffectsAttr = memAttr;
} else {
@@ -798,7 +807,10 @@ class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/noModRef,
- /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
+ /*errnoMem=*/noModRef,
+ /*targetMem0=*/noModRef,
+ /*targetMem1=*/noModRef);
call.setMemoryEffectsAttr(memAttr);
rewriter.replaceOp(op, call);
return success();
@@ -836,7 +848,10 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/noModRef,
- /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
+ /*errnoMem=*/noModRef,
+ /*targetMem0=*/noModRef,
+ /*targetMem1=*/noModRef);
call.setMemoryEffectsAttr(memAttr);
rewriter.replaceOp(op, call);
return success();
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index df955fc..b7a665b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
>();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
+ >();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
@@ -339,19 +343,45 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
}
//===----------------------------------------------------------------------===//
-// ScaledExtPacked816Op
+// ScaledExtPackedMatrixOp
//===----------------------------------------------------------------------===//
-LogicalResult ScaledExtPacked816Op::verify() {
+LogicalResult ScaledExtPackedMatrixOp::verify() {
int blockSize = getBlockSize();
- assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+ assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
+
int firstScaleByte = getFirstScaleByte();
- if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
- return emitOpError(
- "blockSize of 16 can only have firstScaleByte be 0 or 1.");
- }
- if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
- return emitOpError(
- "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+ int firstScaleLane = getFirstScaleLane();
+ auto sourceType = cast<VectorType>(getSource().getType());
+ Type elementType = sourceType.getElementType();
+ auto floatType = cast<FloatType>(elementType);
+ unsigned bitWidth = floatType.getWidth();
+
+ assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+ const bool isFp8 = bitWidth == 8;
+ const bool isBlock16 = blockSize == 16;
+
+ if (!isFp8) {
+ if (isBlock16) {
+ if (!llvm::is_contained({0, 1}, firstScaleByte)) {
+ return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
+ "or 1 for f4 and f6.");
+ }
+ } else {
+ if (!llvm::is_contained({0, 2}, firstScaleByte)) {
+ return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
+ "or 2 for f4 and f6.");
+ }
+ }
+ } else {
+ if (isBlock16) {
+ bool isValid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
+ ((firstScaleLane == 16) && (firstScaleByte == 2));
+ if (!isValid) {
+ return emitOpError("blockSize of 16 can only have (firstScaleLane, "
+ "firstScaleByte) be (0, 0) or (16, 2) for f8.");
+ }
+ }
}
return success();
@@ -567,6 +597,53 @@ LogicalResult PermlaneSwapOp::verify() {
}
//===----------------------------------------------------------------------===//
+// MemoryCounterWaitOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Fuse adjacent memory counter wait ops, taking the minimum value of the
+/// counters.
+struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
+ PatternRewriter &rewriter) const override {
+ auto next = dyn_cast_if_present<MemoryCounterWaitOp>(op->getNextNode());
+ if (!next)
+ return failure();
+
+ auto setters = {&MemoryCounterWaitOp::setLoad,
+ &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
+ &MemoryCounterWaitOp::setExp,
+ &MemoryCounterWaitOp::setTensor};
+ auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
+ op.getTensor()};
+ auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
+ next.getExp(), next.getTensor()};
+ rewriter.modifyOpInPlace(op, [&] {
+ for (auto [setter, lhs, rhs] :
+ llvm::zip_equal(setters, lhsVals, rhsVals)) {
+ if (lhs && rhs) {
+ (op.*setter)(std::min(*lhs, *rhs));
+ } else if (lhs) {
+ (op.*setter)(*lhs);
+ } else if (rhs) {
+ (op.*setter)(*rhs);
+ }
+ }
+ });
+ rewriter.eraseOp(next);
+ return success();
+ }
+};
+} // namespace
+
+void MemoryCounterWaitOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<FuseMemoryCounterWaitOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
@@ -662,19 +739,123 @@ LogicalResult TransposeLoadOp::verify() {
};
auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
- if (validNumElems == kValidLoadSizeMap.end()) {
+ if (validNumElems == kValidLoadSizeMap.end())
return emitOpError("Unsupported element type size for transpose load: ")
<< elementTypeSize << " bits";
- }
- if (numElements != validNumElems->second) {
+
+ if (numElements != validNumElems->second)
return emitOpError(
"Transferring type size mismatch: expected num of elements: ")
<< validNumElems->second;
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MakeDmaBaseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaBaseOp::verify() {
+
+ auto ldsType = cast<MemRefType>(getLds().getType());
+ auto globalType = cast<MemRefType>(getGlobal().getType());
+ if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
+ return emitOpError(
+ "lds memref must have workgroup address space attribute.");
+ if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
+ return emitOpError(
+ "global memref must have global address space attribute.");
+
+ Type elementType = ldsType.getElementType();
+ unsigned width = elementType.getIntOrFloatBitWidth();
+
+ if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+ return emitOpError(
+ "element type must be 1, 2, 4, or 8 bytes long but type was ")
+ << width << " bits long.";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MakeDmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaDescriptorOp::verify() {
+ ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
+
+ if (globalStaticStrides.empty())
+ return emitOpError("strides must not be empty.");
+ if (globalStaticStrides.back() != 1)
+ return emitOpError("strides for the innermost dimension must be 1.");
+
+ ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
+ size_t rank = globalStaticSizes.size();
+ if (rank > 5)
+ return emitOpError("tensor and tile must be at most of rank 5.");
+ if (rank != globalStaticStrides.size())
+ return emitOpError("strides and sizes must have same rank.");
+
+ ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
+ if (rank != sharedStaticSizes.size())
+ return emitOpError("tensor must have same rank as tile.");
+
+ unsigned elementTypeWidth = getElementTypeWidth();
+ if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth))
+ return emitOpError(
+ "element type width must be 1, 2, 4 or 8 bytes, but was ")
+ << elementTypeWidth << " bits long";
+
+ if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
+ auto atomicBarrierAddressType =
+ cast<MemRefType>(atomicBarrierAddress.getType());
+ bool barrierInLDS =
+ hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace());
+ if (!barrierInLDS)
+ return emitOpError("atomic barrier address must be in LDS.");
}
+ if (getEarlyTimeout() && !getWorkgroupMask())
+ return emitOpError(
+ "early timeout does not apply when workgroup_mask is not set.");
return success();
}
+OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
+ SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes());
+ SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides());
+ SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes());
+
+ if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/true)) &&
+ failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/true)) &&
+ failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true,
+ /*onlyNonZero=*/true)))
+ return nullptr;
+
+ SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
+ dynamicSharedSizes;
+ SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
+ staticSharedSizes;
+
+ dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
+ staticGlobalSizes);
+ setGlobalStaticSizes(staticGlobalSizes);
+ getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
+
+ dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
+ staticGlobalStrides);
+ setGlobalStaticStrides(staticGlobalStrides);
+ getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
+
+ dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
+ staticSharedSizes);
+ setSharedStaticSizes(staticSharedSizes);
+ getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
+ return getResult();
+}
+
//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//
@@ -813,5 +994,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
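The fusion rule above reduces to a per-counter minimum when both waits set a value, otherwise whichever value is present; a stand-alone sketch of that merge (the helper name is hypothetical):

#include <algorithm>
#include <optional>

// Merge two optional wait counters the way FuseMemoryCounterWaitOp does:
// keep the stricter (smaller) bound when both are present, otherwise
// whichever one is set.
std::optional<unsigned> mergeCounter(std::optional<unsigned> lhs,
                                     std::optional<unsigned> rhs) {
  if (lhs && rhs)
    return std::min(*lhs, *rhs);
  return lhs ? lhs : rhs;
}

// Example: a wait for <=4 outstanding loads followed by a wait for <=2 loads
// fuses into a single wait for <=2 loads.
bool mergeExample() { return mergeCounter(4u, 2u) == 2u; }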
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c..89ef51f 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
/// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
- vector::MaskedLoadOp maskedOp) {
- auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+ auto memRefType = dyn_cast<MemRefType>(type);
if (!memRefType)
- return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+ return failure();
Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
- return rewriter.notifyMatchFailure(maskedOp, "no address space");
+ return failure();
if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
- return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+ return failure();
return success();
}
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
PatternRewriter &rewriter) const override {
if (maskedOp->hasAttr(kMaskedloadNeedsMask))
- return failure();
+ return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
- if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
- return failure();
+ if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+ return rewriter.notifyMatchFailure(
+ maskedOp, "isn't a load from a fat buffer resource");
}
// Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
PatternRewriter &rewriter) const override {
+ if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+ return rewriter.notifyMatchFailure(
+ loadOp, "buffer loads are handled by a more specialized pattern");
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
if (failed(maybeCond)) {
- return failure();
+ return rewriter.notifyMatchFailure(loadOp,
+ "isn't loading a broadcasted scalar");
}
Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
PatternRewriter &rewriter) const override {
+ // A condition-free implementation of fully masked stores requires
+ // 1) an accessor for the num_records field on buffer resources/fat pointers
+ // 2) knowledge that said field will always be set accurately - that is,
+ // that writes at offsets x < num_records wouldn't trap, which is
+ // something a pattern user would need to assert or we'd need to prove.
+ //
+ // Therefore, conditional stores to buffers still go down this path at
+ // present.
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
if (failed(maybeCond)) {
return failure();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 0c35921..c6addfb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5421,7 +5421,7 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
return rewriter.notifyMatchFailure(op,
"no unit basis entries to replace");
- if (newIndices.size() == 0) {
+ if (newIndices.empty()) {
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
return success();
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index c942c02..b04e2d6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -82,7 +82,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
newShape[0] = 2;
- std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
+ llvm::copy(oldShape, newShape.begin() + 1);
return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
};
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 4743941..8f1249e 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1711,6 +1711,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
outermost.getBody()->getOperations().splice(
Block::iterator(secondOutermostLoop.getOperation()),
innermost.getBody()->getOperations());
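+  // The second outermost loop is erased below; forward each of its region
+  // iter_args to the corresponding init value so no dangling uses remain.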
+ for (auto [iter, init] :
+ llvm::zip_equal(secondOutermostLoop.getRegionIterArgs(),
+ secondOutermostLoop.getInits())) {
+ iter.replaceAllUsesWith(init);
+ iter.dropAllUses();
+ }
secondOutermostLoop.erase();
return success();
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9f..e256915 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -389,8 +389,8 @@ def TruncIExtUIToExtUI :
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
- (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
- (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
+ (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow),
+ (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index adeb50b..c4e81e5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -35,7 +35,7 @@ static Value createConst(Location loc, Type type, int value,
}
/// Create a float constant.
-static Value createFloatConst(Location loc, Type type, APFloat value,
+static Value createFloatConst(Location loc, Type type, const APFloat &value,
PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 39e398b..cb7c3d7 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -150,7 +150,7 @@ public:
rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask());
}
- auto extOp = op.getLhs().getDefiningOp();
+ auto *extOp = op.getLhs().getDefiningOp();
arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
@@ -311,8 +311,8 @@ public:
rhsMask = packInputs(rhs0Mask, rhs1Mask);
}
- auto lhsExtOp = op.getLhs().getDefiningOp();
- auto rhsExtOp = op.getRhs().getDefiningOp();
+ auto *lhsExtOp = op.getLhs().getDefiningOp();
+ auto *rhsExtOp = op.getRhs().getDefiningOp();
arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index e0cf353..9b11270 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return false;
}
-// bufferization.to_buffer is not allowed to change the rank.
-static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
-#ifndef NDEBUG
- auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
- assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
- rankedTensorType.getRank()) &&
- "to_buffer would be invalid: mismatching ranks");
-#endif
-}
-
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
@@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *bufferType);
+
return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
*bufferType, value)
.getResult();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index d6c3cd6..bd177ba 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel
mlir::LogicalResult verifyCompatibleBufferType(
mlir::Type tensor, BufferLikeType bufferType,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
- assert(isa<TensorType>(tensor) && "expected tensor type");
- assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
-
auto tensorType = cast<ShapedType>(tensor);
auto memrefType = cast<ShapedType>(bufferType);
diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
index 51feec7..f8eb45c 100644
--- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
+++ b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp
@@ -17,6 +17,10 @@
// Pipeline implementation.
//===----------------------------------------------------------------------===//
+void mlir::bufferization::buildBufferDeallocationPipeline(OpPassManager &pm) {
+ buildBufferDeallocationPipeline(pm, BufferDeallocationPipelineOptions());
+}
+
void mlir::bufferization::buildBufferDeallocationPipeline(
OpPassManager &pm, const BufferDeallocationPipelineOptions &options) {
memref::ExpandReallocPassOptions expandAllocPassOptions{
@@ -44,5 +48,7 @@ void mlir::bufferization::registerBufferizationPipelines() {
"The default pipeline for automatically inserting deallocation "
"operations after one-shot bufferization. Deallocation operations "
"(except `memref.realloc`) may not be present already.",
- buildBufferDeallocationPipeline);
+ [](OpPassManager &pm, const BufferDeallocationPipelineOptions &options) {
+ buildBufferDeallocationPipeline(pm, options);
+ });
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1784964..677c0ba 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace bufferization {
@@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- return {};
+ if (!insertionPoint) {
+ // If no already suitable insertion point was found, attempt to move all
+ // needed values before the user.
+ if (failed(moveValueDefinitions(rewriter, neededValues, user)))
+ return {};
+ insertionPoint = user;
+ }
rewriter.setInsertionPoint(insertionPoint);
Value replacement =
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 9ccbfd3..5dfe3e6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -497,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state,
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
state.findValueInReverseUseDefChain(opOperand, matchingSubset);
- return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
+ return llvm::all_of(backwardSlice, matchingSubset);
}
/// Return "true" if the given "read" and potentially conflicting "write" are
diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
index 58551bb..05a787f 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f1da1a1..d2078d8 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -445,6 +446,37 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
return success(replaced);
}
};
+
+/// If the destination block of a conditional branch contains only
+/// ub.unreachable, unconditionally branch to the other destination.
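+///
+/// For example (illustrative IR):
+///   cf.cond_br %cond, ^bb1, ^bb2
+///   ^bb1:
+///     ub.unreachable
+/// becomes
+///   cf.br ^bb2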
+struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // If the "true" destination is unreachable, branch to the "false"
+ // destination.
+ Block *trueDest = condbr.getTrueDest();
+ Block *falseDest = condbr.getFalseDest();
+ if (llvm::hasSingleElement(*trueDest) &&
+ isa<ub::UnreachableOp>(trueDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
+ condbr.getFalseOperands());
+ return success();
+ }
+
+ // If the "false" destination is unreachable, branch to the "true"
+ // destination.
+ if (llvm::hasSingleElement(*falseDest) &&
+ isa<ub::UnreachableOp>(falseDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
+ condbr.getTrueOperands());
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -452,7 +484,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index d478220..b0566dd 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -226,6 +226,21 @@ FailureOr<SmallVector<ReplacementItem>> parseFormatString(
}
//===----------------------------------------------------------------------===//
+// AddressOfOp
+//===----------------------------------------------------------------------===//
+
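+// Illustrative example: taking the address of a value of type
+// !emitc.lvalue<i32> must produce a result of type !emitc.ptr<i32>.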
+LogicalResult AddressOfOp::verify() {
+ emitc::LValueType referenceType = getReference().getType();
+ emitc::PointerType resultType = getResult().getType();
+
+ if (referenceType.getValueType() != resultType.getPointee())
+ return emitOpError("requires result to be a pointer to the type "
+ "referenced by operand");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
@@ -380,6 +395,20 @@ LogicalResult emitc::ConstantOp::verify() {
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
//===----------------------------------------------------------------------===//
+// DereferenceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DereferenceOp::verify() {
+ emitc::PointerType pointerType = getPointer().getType();
+
+ if (pointerType.getPointee() != getResult().getType().getValueType())
+ return emitOpError("requires result to be an lvalue of the type "
+ "pointed to by operand");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ExpressionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb093..d6dfd02 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
return std::make_pair(*newFuncOpOrFailure, newCallOp);
}
+
+FailureOr<func::FuncOp>
+func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT, SymbolTableCollection *symbolTables) {
+ FuncOp func;
+ if (symbolTables) {
+ func = symbolTables->lookupSymbolIn<FuncOp>(
+ symTable, StringAttr::get(symTable->getContext(), name));
+ } else {
+ func = llvm::dyn_cast_or_null<FuncOp>(
+ SymbolTable::lookupSymbolIn(symTable, name));
+ }
+
+ if (!func)
+ return func;
+
+ mlir::FunctionType foundFuncT = func.getFunctionType();
+  // Check that the signature of the found function matches the expected one.
+ if (funcT != foundFuncT) {
+ return func.emitError("matched function '")
+ << name << "' but with different type: " << foundFuncT
+ << " (expected " << funcT << ")";
+ }
+ return func;
+}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2..61a630a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
bool MMAMatrixType::isValidElementType(Type elementType) {
- return elementType.isF16() || elementType.isF32() ||
+ return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
elementType.isInteger(32);
}
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
if (!MMAMatrixType::isValidElementType(elementType))
return emitError()
- << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
+ << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
return success();
}
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index ec68acf..85b7b1ce 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRNVVMToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
+ MLIRVectorToLLVMPass
MLIRVectorToSCF
MLIRXeGPUTransforms
MLIRXeGPUToXeVM
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index 2c3e466..5462cdd 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -72,6 +72,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
ConvertGpuOpsToNVVMOpsOptions opt;
opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
opt.indexBitwidth = options.indexBitWidth;
+ opt.allowPatternRollback = options.allowPatternRollback;
pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt));
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 1a1485b..38313dc 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -63,13 +63,20 @@ void buildGPUPassPipeline(OpPassManager &pm,
if (options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ xegpu::XeGPUPropagateLayoutOptions layoutOptions;
+ layoutOptions.layoutKind = "inst";
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ xegpu::createXeGPUPropagateLayout(layoutOptions));
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}
if (options.xegpuOpLevel == "subgroup" ||
options.xegpuOpLevel == "workgroup") {
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+ xegpu::XeGPUPropagateLayoutOptions layoutOptions;
+ layoutOptions.layoutKind = "lane";
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ xegpu::createXeGPUPropagateLayout(layoutOptions));
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
@@ -104,8 +111,11 @@ void buildPostGPUCommonPassPipeline(
pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
}
pm.addPass(createLowerAffinePass());
+ pm.addPass(createConvertVectorToLLVMPass());
pm.addPass(createConvertToLLVMPass());
pm.addPass(createReconcileUnrealizedCastsPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
// gpu-module-to-binary
{
GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index cd13840..70d2e11 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -143,8 +143,8 @@ private:
};
/// Erases `executeOp` and returns a clone with additional `results`.
-async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
- ValueRange results) {
+static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
+ ValueRange results) {
// Add values to async.yield op.
Operation *yieldOp = executeOp.getBody()->getTerminator();
yieldOp->insertOperands(yieldOp->getNumOperands(), results);
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 212ccc9..8d10aac 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -169,7 +169,7 @@ LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
LogicalResult getOperandSegmentSizes(Operation *op,
ArrayRef<Variadicity> variadicities,
SmallVectorImpl<int> &segmentSizes) {
- return getSegmentSizes(op, "operand", "operand_segment_sizes",
+ return getSegmentSizes(op, "operand", "operandSegmentSizes",
op->getNumOperands(), variadicities, segmentSizes);
}
@@ -180,7 +180,7 @@ LogicalResult getOperandSegmentSizes(Operation *op,
LogicalResult getResultSegmentSizes(Operation *op,
ArrayRef<Variadicity> variadicities,
SmallVectorImpl<int> &segmentSizes) {
- return getSegmentSizes(op, "result", "result_segment_sizes",
+ return getSegmentSizes(op, "result", "resultSegmentSizes",
op->getNumResults(), variadicities, segmentSizes);
}
diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
index 183d0e3..887e8e1 100644
--- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::index;
@@ -15,10 +16,23 @@ using namespace mlir::index;
//===----------------------------------------------------------------------===//
// IndexDialect
//===----------------------------------------------------------------------===//
+namespace {
+/// This class defines the interface for handling inlining for index
+/// dialect operations.
+struct IndexInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ /// All index dialect ops can be inlined.
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+ return true;
+ }
+};
+} // namespace
void IndexDialect::initialize() {
registerAttributes();
registerOperations();
+ addInterfaces<IndexInlinerInterface>();
declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>();
}
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index cc66fac..a73f0c1 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa3..160b6ae 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kPrintApFloat,
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index b8331e0..9f87e50 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) {
MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context,
ArrayRef<ModRefInfo> memInfoArgs) {
if (memInfoArgs.empty())
- return MemoryEffectsAttr::get(context, ModRefInfo::ModRef,
- ModRefInfo::ModRef, ModRefInfo::ModRef);
- if (memInfoArgs.size() == 3)
+ return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef,
+ /*argMem=*/ModRefInfo::ModRef,
+ /*inaccessibleMem=*/ModRefInfo::ModRef,
+ /*errnoMem=*/ModRefInfo::ModRef,
+ /*targetMem0=*/ModRefInfo::ModRef,
+ /*targetMem1=*/ModRefInfo::ModRef);
+ if (memInfoArgs.size() == 6)
return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1],
- memInfoArgs[2]);
+ memInfoArgs[2], memInfoArgs[3],
+ memInfoArgs[4], memInfoArgs[5]);
return {};
}
@@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() {
return false;
if (this->getOther() != ModRefInfo::ModRef)
return false;
+ if (this->getErrnoMem() != ModRefInfo::ModRef)
+ return false;
+ if (this->getTargetMem0() != ModRefInfo::ModRef)
+ return false;
+ if (this->getTargetMem1() != ModRefInfo::ModRef)
+ return false;
return true;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1bf4a1c..5b81948 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -4224,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() {
}
//===----------------------------------------------------------------------===//
+// UDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability UDivOp::getSpeculatability() {
+ // X / 0 => UB
+ Value divisor = getRhs();
+ if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
+ return Speculation::Speculatable;
+
+ return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
+// SDivOp
+//===----------------------------------------------------------------------===//
+Speculation::Speculatability SDivOp::getSpeculatability() {
+ // This function conservatively assumes that all signed division by -1 are
+ // not speculatable.
+ // X / 0 => UB
+ // INT_MIN / -1 => UB
+ Value divisor = getRhs();
+ if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
+ matchPattern(divisor, m_IntRangeWithoutNegOneS()))
+ return Speculation::Speculatable;
+
+ return Speculation::NotSpeculatable;
+}
+
+//===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ce93d18..5dc4fa2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
+static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier";
bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
// See llvm/lib/IR/Type.cpp for reference.
@@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
properties |=
(LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
+ if (getExtTypeName() == kAMDGCNNamedBarrier)
+ properties |= LLVMTargetExtType::CanBeGlobal;
+
return (properties & prop) == prop;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d43f881..5ce56e6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
@@ -48,6 +49,47 @@ using namespace NVVM;
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
//===----------------------------------------------------------------------===//
+// Helper/Utility methods
+//===----------------------------------------------------------------------===//
+
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInGenericSpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
+static bool isPtrInSharedClusterSpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster);
+}
+
+static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder,
+ llvm::Value *ptr,
+ NVVMMemorySpace targetAS) {
+ unsigned AS = static_cast<unsigned>(targetAS);
+ return builder.CreateAddrSpaceCast(
+ ptr, llvm::PointerType::get(builder.getContext(), AS));
+}
+
+// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM
+static llvm::nvvm::CTAGroupKind
+getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
+ switch (ctaGroup) {
+ case NVVM::CTAGroupKind::CTA_1:
+ return llvm::nvvm::CTAGroupKind::CG_1;
+ case NVVM::CTAGroupKind::CTA_2:
+ return llvm::nvvm::CTAGroupKind::CG_2;
+ }
+ llvm_unreachable("unsupported cta_group value");
+}
+
+//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
return success();
}
+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+ bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+ if (isSharedCTA && getMulticastMask())
+ return emitError("Multicast is not supported with shared::cta mode.");
+
+ return success();
+}
+
+static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
+ NVVM::MemScopeKind scope,
+ Value retVal = nullptr) {
+ if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
+ return op->emitError("mbarrier scope must be either CTA or Cluster");
+
+ bool isSharedCluster = isPtrInSharedClusterSpace(addr);
+ bool hasRetValue = static_cast<bool>(retVal);
+ if (isSharedCluster && hasRetValue)
+ return op->emitError(
+ "mbarrier in shared_cluster space cannot return any value");
+
+ return success();
+}
+
+LogicalResult MBarrierArriveOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveDropOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveExpectTxOp::verify() {
+ // The inline-ptx version of this Op does not support all features.
+ // With predicate, this Op lowers to inline-ptx. So, verify and
+ // error-out if there are unsupported features.
+ if (getPredicate()) {
+ if (getScope() != NVVM::MemScopeKind::CTA)
+ return emitError("mbarrier scope must be CTA when using predicate");
+
+ if (isPtrInSharedClusterSpace(getAddr()))
+ return emitError("mbarrier in shared_cluster space is not supported when "
+ "using predicate");
+
+ if (getRes())
+ return emitError("return-value is not supported when using predicate");
+
+    if (getRelaxed())
+ return emitError("mbarrier with relaxed semantics is not supported when "
+ "using predicate");
+ }
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(),
+ getRes());
+}
+
+LogicalResult MBarrierExpectTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierCompleteTxOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierTestWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
+LogicalResult MBarrierTryWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
LogicalResult ConvertFloatToTF32Op::verify() {
using RndMode = NVVM::FPRoundingMode;
switch (getRnd()) {
@@ -365,22 +484,71 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+LogicalResult PermuteOp::verify() {
+ using Mode = NVVM::PermuteMode;
+ bool hasHi = static_cast<bool>(getHi());
+
+ switch (getMode()) {
+ case Mode::DEFAULT:
+ case Mode::F4E:
+ case Mode::B4E:
+ if (!hasHi)
+ return emitError("mode '")
+ << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
+ break;
+ case Mode::RC8:
+ case Mode::ECL:
+ case Mode::ECR:
+ case Mode::RC16:
+ if (hasHi)
+ return emitError("mode '") << stringifyPermuteMode(getMode())
+ << "' does not accept 'hi' operand.";
+ break;
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
-LogicalResult ConvertF32x2ToF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
- "conversions from f32x2 to f16x2.");
+static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
+ FPRoundingMode rnd,
+ bool hasRandomBits,
+ Operation *op) {
+ static constexpr FPRoundingMode validRndModes[] = {
+ FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
+
+ if (!llvm::is_contained(validRndModes, rnd)) {
+ return op->emitOpError(
+ "Only RN, RZ, and RS rounding modes are supported for "
+ "conversions from f32x2 to ")
+ << dstType << ".";
+ }
+
+ if (rnd == FPRoundingMode::RS) {
+ if (!hasRandomBits) {
+ return op->emitOpError("random_bits is required for RS rounding mode.");
+ }
+ } else {
+ if (hasRandomBits) {
+ return op->emitOpError(
+ "random_bits not supported for RN and RZ rounding modes.");
+ }
+ }
+
return success();
}
+LogicalResult ConvertF32x2ToF16x2Op::verify() {
+ return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
+                                        static_cast<bool>(getRandomBits()), *this);
+}
+
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
- "conversions from f32x2 to bf16x2.");
- return success();
+ return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
+                                        static_cast<bool>(getRandomBits()), *this);
}
LogicalResult ConvertF32x4ToF8x4Op::verify() {
@@ -919,6 +1087,482 @@ LogicalResult MmaOp::verify() {
return success();
}
+MMATypes MmaSpOp::accumPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
+ getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
+ assert(val.has_value() && "accumulator PTX type should always be inferrable");
+ return val.value();
+}
+
+MMATypes MmaSpOp::resultPtxType() {
+ std::optional<mlir::NVVM::MMATypes> val =
+ MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
+ assert(val.has_value() && "result PTX type should always be inferrable");
+ return val.value();
+}
+
+mlir::NVVM::IDArgPair
+MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MmaSpOp>(op);
+
+ // Get operands
+ llvm::SmallVector<llvm::Value *> args;
+ for (mlir::Value v : thisOp.getOperands())
+ args.push_back(mt.lookupValue(v));
+
+ // Get intrinsic ID using the existing getIntrinsicID method
+ auto intId = MmaSpOp::getIntrinsicID(
+ thisOp.getShape().getM(), thisOp.getShape().getN(),
+ thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
+ thisOp.getOrderedMetadata(), thisOp.getKind(),
+ *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
+ thisOp.accumPtxType(), thisOp.resultPtxType());
+
+ return {intId, args};
+}
+
+void MmaSpOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+ };
+
+ std::array<OperandFragment, 5> frags{
+ OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
+ OperandFragment("selector", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
+
+ // Handle variadic operands A, B, C
+ for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(this->getOperand(operandIdx));
+ if (operandIdx == varOperandSpec.first) {
+ regTypes.push_back(this->getOperand(operandIdx).getType());
+ }
+ }
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+
+ // Handle sparse metadata and selector (single operands)
+ frags[3].regs.push_back(getSparseMetadata());
+ frags[4].regs.push_back(getSparsitySelector());
+
+ auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+ p << " " << frag.operandName;
+ p << "[";
+ p.printOperands(frag.regs);
+ p << "]";
+ };
+
+ for (const auto &frag : frags)
+ printMmaSpOperand(frag);
+
+ p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+ p << " : ";
+ p << "(";
+ for (int i = 0; i < 3; ++i) {
+ p << regTypes[i];
+ if (i < 2)
+ p << ", ";
+ }
+ p << ") -> " << getResult().getType();
+}
+
+void MmaSpOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
+ std::optional<MMAIntOverflow> intOverflow,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+ MLIRContext *ctx = builder.getContext();
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+ result.addOperands(sparseMetadata);
+ result.addOperands(sparsitySelector);
+
+ if (multiplicandPtxTypes) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+ if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+ }
+
+ if (intOverflow.has_value())
+ result.addAttribute("intOverflowBehavior",
+ MMAIntOverflowAttr::get(ctx, *intOverflow));
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()), 1,
+ 1})); // sparseMetadata and sparsitySelector
+}
+
+ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
+ struct OperandFragment {
+ std::optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ SmallVector<Type> regTypes;
+ };
+
+ Builder &builder = parser.getBuilder();
+  std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector, result
+
+ NamedAttrList namedAttributes;
+
+ // A helper to parse the operand segments.
+ auto parseMmaSpOperand = [&](StringRef operandName,
+ OperandFragment &frag) -> LogicalResult {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser
+ .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
+ return failure();
+ return success();
+ };
+
+ // Parse the operand segments.
+ if (parseMmaSpOperand("A", frags[0]).failed())
+ return failure();
+ if (parseMmaSpOperand("B", frags[1]).failed())
+ return failure();
+ if (parseMmaSpOperand("C", frags[2]).failed())
+ return failure();
+ if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
+ return failure();
+ if (parseMmaSpOperand("selector", frags[4]).failed())
+ return failure();
+
+ if (parser.parseOptionalAttrDict(namedAttributes).failed())
+ return failure();
+
+ // Parse the type specification and resolve operands.
+ SmallVector<Type, 3> operandTypes;
+ if (failed(parser.parseColon()))
+ return failure();
+ if (failed(parser.parseLParen()))
+ return failure();
+ if (failed(parser.parseTypeList(operandTypes)))
+ return failure();
+ if (failed(parser.parseRParen()))
+ return failure();
+ if (operandTypes.size() != 3)
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expected one type for each operand segment but got " +
+ Twine(operandTypes.size()) + " types");
+ for (const auto &iter : llvm::enumerate(operandTypes)) {
+ auto &frag = frags[iter.index()];
+ frag.regTypes.resize(frag.regs.size(), iter.value());
+ if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+ parser.getNameLoc(), result.operands)))
+ return failure();
+ frag.elemtype =
+ MmaOp::inferOperandMMAType(frag.regTypes[0],
+ /*isAccumulator*/ iter.index() >= 2);
+ }
+
+ Type resultType;
+ if (parser.parseArrow() || parser.parseType(resultType))
+ return failure();
+ frags[5].elemtype =
+ MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
+
+ // Resolve sparse metadata and selector (assume i32 type)
+ Type i32Type = builder.getIntegerType(32);
+ if (parser
+ .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+ if (parser
+ .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
+ .failed())
+ return failure();
+
+ std::array<StringRef, 2> names{"multiplicandAPtxType",
+ "multiplicandBPtxType"};
+ for (unsigned idx = 0; idx < names.size(); idx++) {
+ const auto &frag = frags[idx];
+ std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+ if (!frag.elemtype.has_value() && !attr.has_value()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "attribute " + names[idx] +
+ " is not provided explicitly and cannot be inferred");
+ }
+ if (!attr.has_value())
+ result.addAttribute(
+ names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+ }
+
+ result.addTypes(resultType);
+ if (!namedAttributes.empty())
+ result.addAttributes(namedAttributes);
+ result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // sparseMetadata
+ 1 // sparsitySelector
+ }));
+ return success();
+}
+
+LogicalResult MmaSpOp::verify() {
+ MLIRContext *context = getContext();
+ auto f16Ty = Float16Type::get(context);
+ auto i32Ty = IntegerType::get(context, 32);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f32Ty = Float32Type::get(context);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+ auto s32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+ auto f32x8StructTy =
+ LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+ auto f16x2x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+ auto f32x4StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+ auto s32x2StructTy =
+ LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+ std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
+ getShapeAttr().getK()};
+
+ // These variables define the set of allowed data types for matrices A, B, C,
+ // and result.
+ using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+ using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+ AllowedShapes allowedShapes;
+ AllowedTypes expectedA;
+ AllowedTypes expectedB;
+ AllowedTypes expectedC;
+ SmallVector<Type> expectedResult;
+
+ // When M = 16, we just need to calculate the number of 8xk tiles, where
+ // k is a factor that depends on the data type.
+ if (mmaShape[0] == 16) {
+ int64_t kFactor;
+ Type multiplicandFragType;
+ switch (*getMultiplicandAPtxType()) {
+ case MMATypes::tf32:
+ kFactor = 4;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k8 and m16n8k16 for tf32
+ allowedShapes.push_back({16, 8, 8});
+ allowedShapes.push_back({16, 8, 16});
+ break;
+ case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ // Sparse MMA supports m16n8k16 and m16n8k32 for bf16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::f16:
+ kFactor = 8;
+ multiplicandFragType = f16x2Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k16 and m16n8k32 for f16
+ allowedShapes.push_back({16, 8, 16});
+ allowedShapes.push_back({16, 8, 32});
+ break;
+ case MMATypes::s4:
+ case MMATypes::u4:
+ kFactor = 32;
+ // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4
+ allowedShapes.push_back({16, 8, 64});
+ allowedShapes.push_back({16, 8, 128});
+ break;
+ case MMATypes::s8:
+ case MMATypes::u8:
+ kFactor = 16;
+ // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8
+ allowedShapes.push_back({16, 8, 32});
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ case MMATypes::e4m3:
+ case MMATypes::e5m2:
+ case MMATypes::e3m2:
+ case MMATypes::e2m3:
+ case MMATypes::e2m1:
+ kFactor = 32;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(f16x2x2StructTy);
+ expectedResult.push_back(f32x4StructTy);
+ // Sparse MMA supports m16n8k64 for FP8 types
+ allowedShapes.push_back({16, 8, 64});
+ break;
+ default:
+ return emitError("invalid shape or multiplicand type: " +
+ stringifyEnum(getMultiplicandAPtxType().value()));
+ }
+
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedResult.push_back(s32x4StructTy);
+ expectedC.emplace_back(4, i32Ty);
+ multiplicandFragType = i32Ty;
+ } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
+ *getMultiplicandAPtxType() <= MMATypes::e2m1) {
+ // FP8 types
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ } else {
+ expectedC.emplace_back(2, f16x2Ty);
+ expectedC.emplace_back(4, f32Ty);
+ }
+
+ // For sparse MMA, A operand is compressed (2:4 sparsity means half the
+ // elements)
+ int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
+ int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+ expectedA.emplace_back(unitA, multiplicandFragType);
+ expectedB.emplace_back(unitB, multiplicandFragType);
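+    // For example, f16 (kFactor = 8) with shape m16n8k16 yields
+    // unitA = (16/8) * (16/8) / 2 = 2 and unitB = (8/8) * (16/8) = 2
+    // vector<2xf16> fragments.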
+
+ if (resultPtxType() != accumPtxType())
+ return emitOpError("ctype does not match dtype");
+ }
+
+ // In the M=8 case, there is only 1 possible case per data type.
+ if (mmaShape[0] == 8) {
+ if (*getMultiplicandAPtxType() == MMATypes::f16) {
+ expectedA.emplace_back(2, f16x2Ty);
+ expectedB.emplace_back(2, f16x2Ty);
+ expectedResult.push_back(f16x2x4StructTy);
+ expectedResult.push_back(f32x8StructTy);
+ expectedC.emplace_back(4, f16x2Ty);
+ expectedC.emplace_back(8, f32Ty);
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (*getMultiplicandAPtxType() == MMATypes::f64) {
+ Type f64Ty = Float64Type::get(context);
+ expectedA.emplace_back(1, f64Ty);
+ expectedB.emplace_back(1, f64Ty);
+ expectedC.emplace_back(2, f64Ty);
+ expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(2, f64Ty)));
+ allowedShapes.push_back({8, 8, 4});
+ }
+ if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+ expectedA.push_back({i32Ty});
+ expectedB.push_back({i32Ty});
+ expectedC.push_back({i32Ty, i32Ty});
+ expectedResult.push_back(s32x2StructTy);
+ if (isInt4PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 32});
+ if (isInt8PtxType(getMultiplicandAPtxType().value()))
+ allowedShapes.push_back({8, 8, 16});
+ }
+ }
+
+ std::string errorMessage;
+ llvm::raw_string_ostream errorStream(errorMessage);
+
+ // Check that we matched an existing shape/dtype combination.
+ if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+ !llvm::is_contained(allowedShapes, mmaShape)) {
+ errorStream << "unimplemented variant for MMA shape <";
+ llvm::interleaveComma(mmaShape, errorStream);
+ errorStream << ">";
+ return emitOpError(errorMessage);
+ }
+
+ // Verify the operand types for segments of A, B, and C operands.
+ std::array<StringRef, 3> operandNames{"A", "B", "C"};
+ for (const auto &iter : llvm::enumerate(
+ SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+ auto spec = this->getODSOperandIndexAndLength(iter.index());
+ SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+ operand_type_begin() + spec.first +
+ spec.second);
+ bool match = llvm::is_contained(iter.value(), operandTySeg);
+
+ if (!match) {
+ errorStream << "Could not match types for the "
+ << operandNames[iter.index()]
+ << " operands; expected one of ";
+ for (const auto &x : iter.value()) {
+ errorStream << x.size() << "x" << x[0] << " ";
+ }
+ errorStream << "but got ";
+ llvm::interleaveComma(operandTySeg, errorStream);
+ return emitOpError(errorMessage);
+ }
+ }
+
+ // Check the result type
+ if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+ return expectedResultType == getResult().getType();
+ })) {
+ errorStream
+ << "Could not match allowed types for the result; expected one of ";
+ llvm::interleaveComma(expectedResult, errorStream);
+ errorStream << " but got " << getResult().getType();
+ return emitOpError(errorMessage);
+ }
+
+ // Ensure int4/int8 MMA variants specify the accum overflow behavior
+ // attribute.
+ if (isInt4PtxType(*getMultiplicandAPtxType()) ||
+ isInt8PtxType(*getMultiplicandAPtxType())) {
+ if (!getIntOverflowBehavior())
+ return emitOpError("op requires " +
+ getIntOverflowBehaviorAttrName().strref() +
+ " attribute");
+ }
+
+ // Validate sparse metadata type (should be i32)
+ if (!getSparseMetadata().getType().isInteger(32)) {
+ return emitOpError() << "sparse metadata must be i32 type";
+ }
+
+ // Validate sparsity selector type (should be i32)
+ if (!getSparsitySelector().getType().isInteger(32)) {
+ return emitOpError() << "sparsity selector must be i32 type";
+ }
+
+ return success();
+}
+
LogicalResult ShflOp::verify() {
auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
@@ -1454,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
return true; // Has manual mapping
}
+LogicalResult NVVM::FenceSyncRestrictOp::verify() {
+ if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+ getOrder() != NVVM::MemOrderKind::RELEASE)
+ return emitOpError("only acquire and release semantics are supported");
+ return success();
+}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1476,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() {
if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
return emitOpError("uni-directional proxies only support tensormap "
"for to_proxy attribute");
-
return success();
}
@@ -1488,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() {
if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
return emitOpError("uni-directional proxies only support tensormap "
"for to_proxy attribute");
+ return success();
+}
+
+LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
+ if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
+ getOrder() != NVVM::MemOrderKind::RELEASE)
+ return emitOpError("only acquire and release semantics are supported");
+
+ if (getFromProxy() != NVVM::ProxyKind::GENERIC)
+ return emitOpError("only generic is support for from_proxy attribute");
+ if (getToProxy() != NVVM::ProxyKind::async)
+ return emitOpError("only async is supported for to_proxy attribute");
return success();
}
@@ -1504,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() {
if (getNumberOfThreads() && !getBarrierId())
return emitOpError(
"barrier id is missing, it should be set between 0 to 15");
+
+ if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
+ return emitOpError("reduction are only available when id is 0");
+
+ if ((getReductionOp() && !getReductionPredicate()) ||
+ (!getReductionOp() && getReductionPredicate()))
+ return emitOpError("reduction predicate and reduction operation must be "
+ "specified together");
+
return success();
}
@@ -1741,24 +2412,68 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
//===----------------------------------------------------------------------===//
std::string NVVM::MBarrierInitOp::getPtx() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
- return (addressSpace == NVVMMemorySpace::Shared)
- ? std::string("mbarrier.init.shared.b64 [%0], %1;")
- : std::string("mbarrier.init.b64 [%0], %1;");
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
+}
+
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared
+ ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
+ : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
+}
+
+std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ llvm::StringRef space = isShared ? ".shared" : "";
+
+ return llvm::formatv("{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}",
+ space);
}
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
-static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
- auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
- return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
-}
+mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::BarrierOp>(op);
+ llvm::Value *barrierId = thisOp.getBarrierId()
+ ? mt.lookupValue(thisOp.getBarrierId())
+ : builder.getInt32(0);
+ llvm::Intrinsic::ID id;
+ llvm::SmallVector<llvm::Value *> args;
+ if (thisOp.getNumberOfThreads()) {
+ id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
+ args.push_back(barrierId);
+ args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
+ } else if (thisOp.getReductionOp()) {
+ switch (*thisOp.getReductionOp()) {
+ case NVVM::BarrierReduction::AND:
+ id = llvm::Intrinsic::nvvm_barrier0_and;
+ break;
+ case NVVM::BarrierReduction::OR:
+ id = llvm::Intrinsic::nvvm_barrier0_or;
+ break;
+ case NVVM::BarrierReduction::POPC:
+ id = llvm::Intrinsic::nvvm_barrier0_popc;
+ break;
+ }
+ args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
+ } else {
+ id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
+ args.push_back(barrierId);
+ }
-static bool isPtrInSharedCTASpace(mlir::Value ptr) {
- return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+ return {id, std::move(args)};
}
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
@@ -1787,15 +2502,213 @@ mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
return {id, {mt.lookupValue(thisOp.getAddr())}};
}
+mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
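+  // e.g. cluster scope with shared::cta space encodes as index 2, selecting
+  // the *_scope_cluster_space_cta intrinsic below.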
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getTxcount()));
+
+ return {IDs[index], std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getTxcount()));
+
+ return {IDs[index], std::move(args)};
+}
+
mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
- bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
- llvm::Intrinsic::ID id = isShared
- ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared
- : llvm::Intrinsic::nvvm_mbarrier_arrive;
- return {id, {mt.lookupValue(thisOp.getAddr())}};
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // When count is not explicitly specified, the default is 1.
+ llvm::LLVMContext &ctx = mt.getLLVMContext();
+ bool hasCount = static_cast<bool>(thisOp.getCount());
+ llvm::Value *count =
+ hasCount ? mt.lookupValue(thisOp.getCount())
+ : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
+
+ return {id, {mbar, count}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // When count is not explicitly specified, the default is 1.
+ llvm::LLVMContext &ctx = mt.getLLVMContext();
+ bool hasCount = static_cast<bool>(thisOp.getCount());
+ llvm::Value *count =
+ hasCount ? mt.lookupValue(thisOp.getCount())
+ : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
+
+ return {id, {mbar, count}};
+}
+
+bool MBarrierArriveExpectTxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+  // Add all the operands, but not the attributes, to the asmValues list.
+  // The attributes only select the right intrinsic variants during lowering,
+  // so they are ignored when generating inline PTX.
+ for (auto val : getOperands())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+ return false;
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, txcount}};
+}
+
+mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
+
+ bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr());
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: Space
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, txcount}};
}
mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
@@ -1813,17 +2726,100 @@ mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
+mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
- auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
+ auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
- llvm::Intrinsic::ID id = isShared
- ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
- : llvm::Intrinsic::nvvm_mbarrier_test_wait;
+ llvm::Intrinsic::ID id =
+ isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
+ : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
// Fill the Intrinsic Args
llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(thisOp.getAddr()));
- args.push_back(mt.lookupValue(thisOp.getState()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
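+  // A 32-bit state-or-phase operand selects the phase-parity wait variant;
+  // otherwise the 64-bit mbarrier state is used.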
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // bit-0: isPhaseParity
+ // bit-1: Scope
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the Intrinsic Args
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, input}};
+}
+
+mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
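+  // The optional ticks operand selects the _tl variants of try_wait.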
+ bool hasTicks = static_cast<bool>(thisOp.getTicks());
+ // bit-0: isPhaseParity
+ // bit-1: Scope
+ // bit-2: hasTicks
+ size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
+ (isPhaseParity ? 1 : 0);
+
+ // clang-format off
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
+ // clang-format on
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // Tidy-up the mbarrier pointer
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mbar);
+ args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
+ if (hasTicks)
+ args.push_back(mt.lookupValue(thisOp.getTicks()));
return {id, std::move(args)};
}
@@ -1914,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
args.push_back(mt.lookupValue(thisOp.getSize()));
- // Multicast mask, if available.
+ // Multicast mask for shared::cluster only, if available.
mlir::Value multicastMask = thisOp.getMulticastMask();
const bool hasMulticastMask = static_cast<bool>(multicastMask);
- llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
- args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+ const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+ if (!isSharedCTA) {
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+ : i16Unused);
+ }
// Cache hint, if available.
mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1927,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
// Flag arguments for multicast and cachehint.
- args.push_back(builder.getInt1(hasMulticastMask));
+ if (!isSharedCTA)
+ args.push_back(builder.getInt1(hasMulticastMask));
args.push_back(builder.getInt1(hasCacheHint));
llvm::Intrinsic::ID id =
- llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+ isSharedCTA
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+ : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
return {id, std::move(args)};
}
@@ -2646,30 +3649,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
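+  // Forward the optional random-bits operand only when present (used by the
+  // RS rounding variants).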
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
}
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+ }
}
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
@@ -3010,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}
+mlir::NVVM::IDArgPair
+PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::PermuteOp>(op);
+ NVVM::PermuteMode mode = thisOp.getMode();
+
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
+ llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
+ llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
+ llvm::Intrinsic::nvvm_prmt_rc16};
+
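+  // The IDs table above follows the PermuteMode enum order, so the mode value
+  // indexes it directly.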
+ unsigned modeIndex = static_cast<unsigned>(mode);
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getLo()));
+
+  // Only the first three modes (Default, f4e, b4e) need the hi operand.
+ if (modeIndex < 3)
+ args.push_back(mt.lookupValue(thisOp.getHi()));
+
+ args.push_back(mt.lookupValue(thisOp.getSelector()));
+
+ return {IDs[modeIndex], args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
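+  // A pointer-typed matrix A indicates tensor memory; otherwise the
+  // shared-memory variants are selected.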
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ const bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
+ static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
+      { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ const unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
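+  // The disable-output-lane intrinsics encode the CTA group in their name
+  // (cg1/cg2), so the flag is only passed as an argument otherwise.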
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
+ NVVM::CTAGroupKind ctaGroup, bool hasAShift,
+ NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
+
+ if (disableOutputLane) {
+ mlir::VectorType disableOutputLaneType =
+ cast<mlir::VectorType>(disableOutputLane.getType());
+ if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
+ disableOutputLaneType.getNumElements() != 4) ||
+ (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
+ disableOutputLaneType.getNumElements() != 8))
+ return emitError(loc) << "Disable Output Lane of length "
+ << disableOutputLaneType.getNumElements()
+ << " is incompatible with CtaGroupAttr";
+ }
+
+ if (hasAShift && !isATensor)
+ return emitError(
+ loc, "A-shift can be applied only when matrix A is in tensor memory");
+
+  if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
+                    collectorOp == Tcgen05MMACollectorOp::USE))
+ return emitError(
+ loc, "Cannot use collector buffer operation fill or use with ashift");
+
+ return success();
+}
+
+LogicalResult Tcgen05MMAOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+ using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
+ using CtaGroupArray = std::array<EnableAShiftArray, 2>;
+ using IsATensorArray = std::array<CtaGroupArray, 2>;
+ using HasScaleInputDArray = std::array<IsATensorArray, 2>;
+ using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
+
+ // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
+ static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
+      { // without disable output lane
+ {{// without scale input D
+ {{
+ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ }}},
+ }},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
+ notIntrinsic}}},
+ {{// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ }}}}}}},
+ // with disable output lane
+ {{ // without scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
+ notIntrinsic}}},
+ {{// cg1
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
+ }}}}},
+ // with scale input D
+ {{ // shared
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}}},
+ // tensor
+ {{// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}}}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
+ : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
+
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
+ NVVM::Tcgen05MMABlockScaleKind kind,
+ NVVM::Tcgen05MMABlockScale blockScale,
+ Location loc) {
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
+ kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
+ return emitError(loc, "mxf4nvf4 requires block scale attribute");
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
+ kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
+ return emitError(loc,
+ llvm::formatv("{} kind does not support block16 attribute",
+ stringifyEnum(kind)));
+
+ return success();
+}
+
+LogicalResult Tcgen05MMABlockScaleOp::verify() {
+ return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
+ getBlockScale(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
+
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
+ return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
+ getBlockScale(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
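+  // The presence of a zero-column-mask operand selects the *_zero_col_mask
+  // intrinsic variants.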
+ mlir::Value ZeroColMask = thisOp.getZeroColMask();
+ llvm::Intrinsic::ID ID = notIntrinsic;
+ if (ZeroColMask) {
+ args.push_back(mt.lookupValue(ZeroColMask));
+ ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
+  } else {
+    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
+                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
+  }
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+ mlir::Value ZeroColMask = thisOp.getZeroColMask();
+ llvm::Intrinsic::ID ID = notIntrinsic;
+ if (ZeroColMask) {
+ args.push_back(mt.lookupValue(ZeroColMask));
+ ID = isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
+ : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
+  } else {
+    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
+                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
+  }
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ return {ID, args};
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -3213,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
"Minimum NVVM target SM version is sm_20");
}
- gpuModuleOp->walk([&](Operation *op) {
- if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
- const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
- if (!requirement.isCompatibleWith(targetSMVersion)) {
- op->emitOpError() << "is not supported on " << getChip();
- return WalkResult::interrupt();
- }
- }
- return WalkResult::advance();
- });
+ if (gpuModuleOp
+ ->walk([&](Operation *op) {
+ if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
+ const NVVMCheckSMVersion requirement =
+ reqOp.getRequiredMinSMVersion();
+ if (!requirement.isCompatibleWith(targetSMVersion)) {
+ op->emitOpError() << "is not supported on " << getChip();
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ })
+ .wasInterrupted())
+ return failure();
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
index 67573c4..12dd225 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
@@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr,
return FusedLoc::get(context, {loc}, lexicalBlockFileAttr);
}
+/// Adds DILexicalBlockFileAttr for operations with CallSiteLoc and operations
+/// from different files than their containing function.
static void setLexicalBlockFileAttr(Operation *op) {
- if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) {
+ Location opLoc = op->getLoc();
+
+ if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) {
auto callerLoc = callSiteLoc.getCaller();
auto calleeLoc = callSiteLoc.getCallee();
LLVM::DIScopeAttr scopeAttr;
@@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) {
op->setLoc(
CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc));
}
+
+ return;
+ }
+
+ auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ if (!funcOp)
+ return;
+
+ FileLineColLoc opFileLoc = extractFileLoc(opLoc);
+ if (!opFileLoc)
+ return;
+
+ FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc());
+ if (!funcFileLoc)
+ return;
+
+ StringRef opFile = opFileLoc.getFilename().getValue();
+ StringRef funcFile = funcFileLoc.getFilename().getValue();
+
+ // Handle cross-file operations: add DILexicalBlockFileAttr when the
+ // operation's source file differs from its containing function.
+ if (opFile != funcFile) {
+ auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc());
+ if (!funcOpLoc)
+ return;
+ auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata());
+ if (!scopeAttr)
+ return;
+
+ auto *context = op->getContext();
+ LLVM::DIFileAttr opFileAttr =
+ LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile),
+ llvm::sys::path::parent_path(opFile));
+
+ LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr =
+ LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0);
+
+ Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr);
+ op->setLoc(newLoc);
}
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9..b4b1347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
+namespace {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
- NotScalarInput
+ NotScalarInput,
+ TypeMismatch
};
+} // namespace
static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
+ // Check that the scalar input type matches the output element type.
+ OpOperand *output = linalgOp.getDpsInitOperand(0);
+ Type scalarType = value->get().getType();
+ Type outputElementType = getElementTypeOrSelf(output->get().getType());
+ if (scalarType != outputElementType)
+ return MatchFillResult::TypeMismatch;
+
return MatchFillResult::Success;
}
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
- auto res = isFillInterfaceImpl(op);
+ MatchFillResult res = isFillInterfaceImpl(op);
if (res == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
if (res == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
if (res == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
+ if (res == MatchFillResult::TypeMismatch) {
+ auto linalgOp = cast<linalg::LinalgOp>(op);
+ Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+ Type outputElementType =
+ getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+ return op->emitOpError("expected fill value type (")
+ << scalarType << ") to match output element type ("
+ << outputElementType << ")";
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45ed..33ec79b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
-LogicalResult GenericOp::verify() { return success(); }
-
namespace {
/// Remove linalg operations that are just copying the values from inputs to
@@ -2091,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
return failure();
// Single dimension transpose.
- if (getPermutation().size() == 0) {
+ if (getPermutation().empty()) {
result.push_back(getInput());
return success();
}
@@ -4885,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
elidedAttrs);
}
-LogicalResult ElementwiseOp::verify() {
- // All necessary checks are done either by
- // - EnumAttr (e.g. unknown operation kind)
- // - verifyStructuredOpInterface (incorrect map, sizes).
- return success();
-}
-
/// Implements the block region builder for the ElementwiseOp. This is called by
/// 'fillStructuredOpRegion'.
void ElementwiseOp::regionBuilder(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index aa82063..b8c1bad 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -176,7 +176,8 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
reified.push_back(cast<IntegerAttr>(attr).getInt());
continue;
- } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
+ }
+ if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
if (params.size() != 1)
return transformOp.emitSilenceableError() << "expected a single param";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690da..9e6c1e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides));
+ reassociation->size(), sliceOp.getSourceType(), sizes));
Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cb..421ab5e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,62 @@ private:
ControlFusionFn controlFoldingReshapes;
};
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+ // The resulting shape after padding each dimension.
+ SmallVector<int64_t> paddedShape;
+
+ // Low and high padding amounts for each dimension.
+ SmallVector<OpFoldResult> lowPad;
+ SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ // If the padding value depends on the index values of the pad operation,
+ // then it may not be valid to expand the dimensions, since it will change
+ // the index values on which the padding value depends. This is not currently
+ // supported by the pad expansion patterns, but it could be implemented
+ // similarly to the expansion of linalg.generic ops with linalg.index ops in
+ // the body, as is done in `updateExpandedGenericOpRegion`.
+ if (!padOp.getConstantPaddingValue())
+ return failure();
+
+ // Expanded dimensions cannot have padding because the resulting padding may
+ // not be representable by a tensor.pad op. There are some special cases where
+ // it is possible (like expanding unit dims), but supporting these cases is
+ // NYI, so disallow it for now.
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && (l != 0 || h != 0))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.paddedShape.assign(expandedShape);
+ padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
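+  // Only dims that are not expanded (single-entry reassociations) carry
+  // padding; copy their padded size and padding amounts over.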
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+ padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+ padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+ }
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1053,46 +1109,96 @@ public:
padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
return failure();
- if (!reshapeOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo &expandedPadding = maybeExpandedPadding.value();
- for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && (l != 0 || h != 0))
- return failure();
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+ expandedPadding.lowPad, expandedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
}
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType expandedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+ RankedTensorType expandedType = expandOp.getResultType();
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo &expandedPadding = maybeExpandedPadding.value();
+
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+ SmallVector<int64_t> newExpandedShape(expandedType.getShape());
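+    // Create the new expand_shape (and the ops computing its output sizes)
+    // right after the pad's source so its operands dominate it.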
+ rewriter.setInsertionPointAfterValue(padOp.getSource());
+ SmallVector<OpFoldResult> padSrcSizes =
+ tensor::getMixedSizes(rewriter, loc, padOp.getSource());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ // We know that any reassociation with multiple dims is not padded because
+ // of the requirements of computeExpandedPadding.
if (reInd.size() == 1) {
- expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
- }
- for (size_t i = 0; i < reInd.size(); ++i) {
- newLow.push_back(padOp.getMixedLowPad()[idx]);
- newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+ newExpandedSizes[reInd[0]] = padSrcSizes[idx];
}
}
-
- Location loc = padOp->getLoc();
- RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+ RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+ auto newExpandOp = tensor::ExpandShapeOp::create(
+ rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+ newExpandedSizes);
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+ rewriter.setInsertionPoint(expandOp);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+ expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
return success();
}
@@ -1921,6 +2027,62 @@ private:
ControlFusionFn controlFoldingReshapes;
};
+/// Computes the collapsed padding information for the given pad operation
+/// based on the provided reassociation indices. Returns a PadDimInfo
+/// containing the low and high padding amounts and the collapsed padded
+/// shape, or failure if the collapse is not possible.
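+/// Illustrative example (hypothetical shapes): for a pad producing
+/// tensor<4x10x3x5xf32> with low = [0, 1, 0, 0], high = [0, 1, 0, 0] and
+/// reassociations [[0], [1], [2, 3]], the result is lowPad = [0, 1, 0],
+/// highPad = [0, 1, 0] and paddedShape = [4, 10, 15].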
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ // If the padding value depends on the index values of the pad operation,
+ // then it may not be valid to collapse the dimensions, since it will change
+ // the index values on which the padding value depends. This is not currently
+ // supported by the pad collapsing patterns, but it could be implemented
+ // similarly to the collapsing of linalg.generic ops with linalg.index ops in
+ // the body, as is done in `generateCollapsedIndexingRegion`.
+ if (!padOp.getConstantPaddingValue())
+ return failure();
+
+ // Collapsed dimensions cannot have padding because this can produce strided
+ // padding that isn't representable by a tensor.pad op. There are some special
+ // cases where it is possible (like collapsing unit dims), but supporting
+ // these cases is NYI, so disallow it for now.
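+  // For example, collapsing a tensor<2x3xf32> that is padded on the inner
+  // dimension into a tensor<8xf32> would require inserting padding between
+  // every group of 3 elements, which tensor.pad cannot express.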
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (int64_t dim : reInd) {
+ if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+ return failure();
+ }
+ }
+
+  // Initialize the collapsed low/high padding values to zero.
+ ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+ // Update padding for dimensions that are not being collapsed, and compute
+ // the collapsed padded shape.
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+ padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
+ }
+ SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+ for (int64_t dim : reInd) {
+ collapsedSize =
+ collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+ }
+ padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+ }
+
+ return padDimInfo;
+}
+
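+/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
+/// by collapsing, i.e. pad(expand_shape(x)) -> expand_shape(pad(x)), provided
+/// that only non-collapsed dimensions are padded.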
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1936,57 +2098,40 @@ public:
padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
return failure();
- if (!reshapeOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
- for (auto reInd : reassociations) {
- if (reInd.size() == 1)
- continue;
- if (llvm::any_of(reInd, [&](int64_t ind) {
- return low[ind] != 0 || high[ind] != 0;
- })) {
- return failure();
- }
- }
-
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType collapsedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
- SmallVector<OpFoldResult> expandedPaddedSizes(
- getMixedValues(reshapeOp.getStaticOutputShape(),
- reshapeOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> expandedPaddedSizes =
+ reshapeOp.getMixedOutputShape();
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
- for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
- OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
- OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+ for (auto [reInd, l, h] :
+ llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+ collapsedPadding.highPad)) {
if (reInd.size() == 1) {
- collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
- OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
- expandedPaddedSizes[reInd[0]] = paddedSize;
}
- newLow.push_back(l);
- newHigh.push_back(h);
}
RankedTensorType collapsedPaddedType =
- paddedType.clone(collapsedPaddedShape);
+ padOp.getType().clone(collapsedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2145,52 @@ private:
ControlFusionFn controlFoldingReshapes;
};
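+
+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by swapping the two ops, i.e. collapse_shape(pad(x)) ->
+/// pad(collapse_shape(x)). Illustrative example (hypothetical shapes; pad body
+/// elided):
+///
+///   %pad = tensor.pad %x low[0, 0, 2] high[0, 0, 2]
+///       : tensor<2x3x8xf32> to tensor<2x3x12xf32>
+///   %res = tensor.collapse_shape %pad [[0, 1], [2]]
+///       : tensor<2x3x12xf32> into tensor<6x12xf32>
+///
+/// is rewritten to:
+///
+///   %col = tensor.collapse_shape %x [[0, 1], [2]]
+///       : tensor<2x3x8xf32> into tensor<6x8xf32>
+///   %res = tensor.pad %col low[0, 2] high[0, 2]
+///       : tensor<6x8xf32> to tensor<6x12xf32>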
+class FoldReshapeWithProducerPadOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(padOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+ RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
+
+ Location loc = reshapeOp->getLoc();
+ auto newCollapseOp = tensor::CollapseShapeOp::create(
+ rewriter, loc, padOp.getSource(), reassociations);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index 9974ccd..cbd6357 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -200,10 +200,10 @@ static void populateOpPayload(
SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
- SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
- genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
- SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
- newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+ SmallVector<OpOperand *> origOutputOperands =
+ llvm::to_vector(llvm::make_pointer_range(genericOp.getDpsInitsMutable()));
+ SmallVector<OpOperand *> newOutputOperands =
+ llvm::to_vector(llvm::make_pointer_range(newOp.getDpsInitsMutable()));
updateReplacements(origOutputOperands, newOutputOperands,
origOutsToNewOutsPos);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 9436f1c..161d978 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -913,8 +913,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
llvm_unreachable("loop independence prerequisite not met");
// offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
- std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
- offsets.begin());
+ llvm::copy(loopIterationCounts, offsets.begin());
hoistedPackedTensor =
scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
->getResult(0);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d6..c2485a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,69 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to specialize a `genericOp` into a convolution op of type
+/// `ConvOpTy` with the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ // Ops with no dilations and no strides.
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
+}
+
+/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+#define CONV_OP_SPECIALIZER(ConvOpTy) \
+ if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
+ return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
+ strides); \
+ // -----------------------------
+ // Convolution ops.
+ // -----------------------------
+ CONV_OP_SPECIALIZER(linalg::Conv1DOp);
+ CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
+ CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+ CONV_OP_SPECIALIZER(linalg::Conv3DOp);
+ // -----------------------------
+ // Depthwise Convolution ops.
+ // -----------------------------
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
+ // -----------------------------
+ // Pooling ops.
+ // -----------------------------
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
+ CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
+#undef CONV_OP_SPECIALIZER
+ return failure();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -316,6 +379,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
+
+ // Convolution - e.g. *conv/pooling*
+ if (isaConvolutionOpInterface(genericOp)) {
+ return specializeLinalgConvolutions(rewriter, genericOp);
+ }
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 705d6f2..8e14ef4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
SmallVector<OpFoldResult> allShapeSizes =
op.createFlatListOfOperandDims(b, op.getLoc());
AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
- if (!shapeSizesToLoopsMap)
- return failure();
+ assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap");
auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 8a0440b..50a84ac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -167,7 +167,7 @@ struct LinalgOpTilingInterface
llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
if (!dimExpr)
- continue;
+ return failure();
unsigned position = dimExpr.getPosition();
auto it = mappedOffsets.find(position);
if (it != mappedOffsets.end()) {
@@ -357,6 +357,32 @@ struct LinalgOpTilingInterface
/// Inline the op payload and store the result.
return inlinePayload(builder, linalgOp, ivs, indexedValues);
}
+
+ bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) const {
+ // The verifier gives all the necessary requirements for consumer fusion.
+ return true;
+ }
+
+ bool isOpFusableWithProducerSlices(
+ Operation *op, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+
+ auto linalgOp = cast<LinalgOp>(op);
+ SmallVector<AffineMap> indexingMaps =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+ OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+ return linalgOp.getMatchingIndexingMap(&opOperand);
+ });
+ // Check that offsets/sizes are consistent across all operands.
+ OpBuilder b(op);
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps,
+ allOffsets, allSizes, mappedOffsets,
+ mappedSizes));
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 027268c..67e2b9f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
"this is not supported ATM!");
}
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
int64_t srcRank = packOp.getSourceRank();
- int64_t destRank = packOp.getDestRank();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
writeSizes.push_back(tileSizeOfr);
}
- // TODO: Add a constructor for tensor.insert_slice that doesn't require
- // strides nor offsets.
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-
auto insert = tensor::InsertSliceOp::create(
- rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
- writeOffsets, writeSizes, writeStrides);
+ rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
// 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
@@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
- int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
@@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Value source = unpackOp.getSource();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
// The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
@@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// outer-tiled-dims being all 1), this will be
// [ outer-untiled-dims, tile-sizes ]
SmallVector<OpFoldResult> extractSliceSizes;
- // The offset and strides attributes for ExtractSliceOp.
- SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
// Shape for EmptyOp that's used as the init value for TransposeOp below.
// This should be:
@@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
Type elemType = unpackOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
Value innerTile = tensor::ExtractSliceOp::create(
- rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
- extractSliceSizes, extractSliceStrides);
+ rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
// 2. Transpose the tile to match the outer corresponding tile order.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
@@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// 3. Handle incomplete tiles if needed. It truncates trailing data from the
// transposed tile.
- int numLoops = shapeForEmptyOp.size();
- SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
- SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
SmallVector<OpFoldResult> tileSizes;
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
for (auto i : llvm::seq<unsigned>(0, destRank)) {
@@ -1393,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
}
auto partialTile =
- tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
- tileOffsets, tileSizes, tileStrides);
+ tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
+ transposedOp.getResult()[0], tileSizes);
// 4. Insert the result to the destination tensor.
SmallVector<OpFoldResult> writeSizes;
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
for (int i = 0, idx = 0; i < destRank; ++i) {
if (dimAndTileMapping.count(i) || destShape[i] != 1)
writeSizes.push_back(tileSizes[idx++]);
@@ -1407,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
writeSizes.push_back(oneIdxAttr);
}
auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
- unpackOp.getDest(), writeOffsets,
- writeSizes, writeStrides);
+ unpackOp.getDest(), writeSizes);
rewriter.replaceOp(unpackOp, insert.getResult());
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 19d2d85..bb3bccd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
auto vectorType = state.getCanonicalVecType(
getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
+ SmallVector<Value> indices(linalgOp.getRank(outputOperand),
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
+
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
- SmallVector<Value> indices(
- linalgOp.getRank(outputOperand),
- arith::ConstantIndexOp::create(rewriter, loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType);
assert(value.getType() == vectorType && "Incorrect type");
write = vector::TransferWriteOp::create(
@@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
assert(value.getType() == vectorType && "Incorrect type");
write = vector::TransferWriteOp::create(rewriter, loc, value,
- outputOperand->get(), ValueRange{});
+ outputOperand->get(), indices);
}
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
@@ -1890,9 +1890,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
// Create masked TransferReadOp.
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
- useInBoundsInsteadOfMasking,
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, packOp.getSource(), readVecType, padValue,
+ useInBoundsInsteadOfMasking);
// Create ShapeCastOp.
auto shapeCastOp = vector::ShapeCastOp::create(
@@ -1977,9 +1976,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
}
// -- Generate the read operation --
+ VectorType readVecType =
+ VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+ readScalableVectorFlags);
Value readResult = vector::createReadOrMaskedRead(
- rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
- useInBoundsInsteadOfMasking, readScalableVectorFlags);
+ rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+ useInBoundsInsteadOfMasking);
// -- Generate the transpose operation --
PackingMetadata packMetadata;
@@ -2025,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
+ auto readType = VectorType::get(inputVectorSizes, padValue.getType());
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+ rewriter, loc, padOp.getSource(), readType, padValue,
+ /*useInBoundsInsteadOfMasking=*/false);
// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2222,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, opOperand.get(), readType.getShape(),
+ rewriter, loc, opOperand.get(), readType,
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
- /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ /*useInBoundsInsteadOfMasking=*/false);
vecOperands.push_back(read);
}
@@ -3165,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
SmallVector<Value> readIndices(
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, source, vecType.getShape(), padValue,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, source, vecType, padValue,
+ /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
// Create write
auto writeIndices =
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6eeb206..01e6e1e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -235,6 +235,731 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
+/// ext* ops.
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+ BlockArgument blockArg = dyn_cast<BlockArgument>(val);
+  if (blockArg)
+ return blockArg;
+
+ Operation *defOp = val.getDefiningOp();
+ if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
+ !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
+ !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+ return nullptr;
+ }
+ return dyn_cast<BlockArgument>(defOp->getOperand(0));
+}
+
+/// Utility to match the block body of convolution ops.
+/// The body is expected to yield:
+///        %out + (%lhs * %rhs)
+/// where %lhs, %rhs, and %out are block arguments, and
+///       %lhs and %rhs may be wrapped in an optional upcast operation.
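+/// Illustrative example of a matching body (hypothetical f32 types):
+///   ^bb0(%in: f32, %flt: f32, %acc: f32):
+///     %prod = arith.mulf %in, %flt : f32
+///     %sum = arith.addf %acc, %prod : f32
+///     linalg.yield %sum : f32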
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+ Operation *addOp = yieldVal.getDefiningOp();
+ if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
+ return false;
+
+ Operation *mulOp = addOp->getOperand(1).getDefiningOp();
+ if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+ return false;
+
+ BlockArgument lhsBlockArg =
+ getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+ BlockArgument rhsBlockArg =
+ getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+ BlockArgument outBlockArg =
+ getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+ if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
+ lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
+ outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
+ rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
+ return false;
+ return true;
+}
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+ Operation *defOp = yieldVal.getDefiningOp();
+ if (!(isa_and_present<OpTypes>(defOp) || ...))
+ return false;
+
+ BlockArgument lhsArg =
+ getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+ BlockArgument rhsArg =
+ getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+ if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
+ rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
+ rhsArg.getArgNumber() != 0)
+ return false;
+ return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+ body);
+}
+
+// max_unsigned ops should not allow float data type.
+// TODO(#164800): Retire OPDSL logic.
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+ body);
+}
+
+// min_unsigned ops should not allow float data type.
+// TODO(#164800): Retire OPDSL logic.
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
+ uint32_t dimIndex) {
+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+ if (dimIndex < affineMap.getNumResults())
+ return affineMap.getResult(dimIndex);
+ return nullptr;
+}
+
+/// Check if `expr` is either:
+/// - a dimension expr alone (implying multiplication by 1), or
+/// - a multiplication of dimension expr by any positive constant != 1
+/// In both cases we will capture the dimension expression into `dim` and
+/// return the constant multiplier. Returns -1 in case of a match failure.
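+/// For example, `d1 * 2` captures `d1` and returns 2, `d4` alone captures
+/// `d4` and returns 1, and `d1 + d2` returns -1.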
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+ if ((dim = dyn_cast<AffineDimExpr>(expr)))
+ return 1;
+
+ auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+ return -1;
+
+ AffineExpr lhs = mulExpr.getLHS();
+ AffineExpr rhs = mulExpr.getRHS();
+
+ AffineConstantExpr cst = nullptr;
+ if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+ (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+ ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+ (cst = dyn_cast<AffineConstantExpr>(lhs))))
+ return cst.getValue();
+ return -1;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following
+/// (commutatively):
+/// indexingMaps[0].getResult(iDim) ==
+/// indexingMaps[1].getResult(fDim) * <c0> +
+/// indexingMaps[n-1].getResult(oDim) * <c1>
+/// where,
+/// - c0 and c1 can be any constant,
+/// - n is the size of the indexingMaps' array,
+/// - 0, 1 and n-1 are input, filter and output map indices respectively,
+/// - iDim, fDim and oDim are the input, filter and output dimension
+/// indices in their respective indexing maps
+/// Example:
+/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+///
+/// Here,
+/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
+/// Therefore,
+/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+/// would return true and update dilation = 3 and stride = 2
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+ unsigned fDim, unsigned oDim,
+ int64_t &dilation, int64_t &stride) {
+ unsigned inputMapIdx = 0, filterMapIdx = 1,
+ outputMapIdx = indexingMaps.size() - 1;
+ AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
+ auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
+ if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+ return false;
+
+ AffineExpr dim0, dim1;
+ int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+ int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
+
+ if (c0 == -1 || c1 == -1)
+ return false;
+ // Pattern matched with dims and constants extracted.
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
+ AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
+ if (dim0 == fExpr && dim1 == oExpr) {
+ dilation = c0;
+ stride = c1;
+ return true;
+ }
+ if (dim1 == fExpr && dim0 == oExpr) {
+ dilation = c1;
+ stride = c0;
+ return true;
+ }
+ return false;
+}
+
+/// Returns true if the given indexing maps match the expected indexing maps.
+static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
+ ArrayAttr indexingMaps, MLIRContext *context) {
+ SmallVector<AffineMap, 4> expectedIndexingMaps =
+ AffineMap::inferFromExprList(mapListExpected, context);
+ return indexingMaps ==
+ ArrayAttr::get(
+ context, llvm::to_vector<4>(llvm::map_range(
+ expectedIndexingMaps, [&](AffineMap m) -> Attribute {
+ return AffineMapAttr::get(m);
+ })));
+}
+
+/// Enum representing pooling operation types used by ConvMatcherBuilder.
+enum class PoolingType {
+ None,
+ MaxSigned,
+ MaxUnsigned,
+ MinSigned,
+ MinUnsigned,
+ Sum
+};
+
+/// Helper class for building convolution op matchers with minimal boilerplate.
+/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
+/// as Pooling ops.
+///
+/// Usage: Create an instance with the op, spatial rank, and output pointers for
+/// extracted dilations/strides. Then chain matchStride() calls for each spatial
+/// dimension, followed by matchMaps() to verify indexing maps, and finally
+/// matchBody() to verify the operation body pattern.
+///
+/// The `matched` flag starts as `true` and is set to `false` if any match step
+/// fails. This allows chaining multiple match calls; once any match fails, all
+/// subsequent calls become no-ops and the final result is `false`.
+///
+/// The `dilations` and `strides` pointers are output parameters that get
+/// populated with the extracted dilation and stride values from the operation's
+/// indexing maps during matchStride() calls. These values are initially set to
+/// 1 for each spatial dimension and updated as patterns are matched.
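+///
+/// Sketch of a 1-D matcher built with this helper (mirroring the
+/// linalg.conv_1d matcher below):
+///
+///   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+///   AffineExpr W = m.dim(0), w = m.dim(1);
+///   bool isMatch = m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+///                      .matchMaps({{m.strided(W, w, 0)}, {w}, {W}})
+///                      .matchBody();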
+class ConvMatcherBuilder {
+ LinalgOp op;
+ MLIRContext *ctx;
+ SmallVector<int64_t> *dilations, *strides;
+ ArrayAttr indexingMaps;
+ PoolingType poolingType;
+ bool matched = true;
+
+public:
+ ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
+ SmallVector<int64_t> *s,
+ PoolingType poolingType = PoolingType::None)
+ : op(op), ctx(op->getContext()), dilations(d), strides(s),
+ indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
+ *dilations = SmallVector<int64_t>(spatialRank, 1);
+ *strides = SmallVector<int64_t>(spatialRank, 1);
+ }
+
+ /// Get affine dimension expression for dimension `i`.
+ AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
+
+ /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
+ AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
+ return base * (*strides)[idx] + kernel * (*dilations)[idx];
+ }
+
+ /// Match stride/dilation pattern for a spatial dimension.
+ /// Returns *this for method chaining.
+ ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
+ unsigned idx) {
+ if (matched) {
+ matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+ (*dilations)[idx], (*strides)[idx]);
+ }
+ return *this;
+ }
+
+ /// Match expected indexing maps layout. Returns *this for method chaining.
+ ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+ if (matched)
+ matched &= convLayoutMatches(maps, indexingMaps, ctx);
+ return *this;
+ }
+
+ /// Match body pattern. This should be called last.
+ bool matchBody() {
+ if (!matched)
+ return false;
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ switch (poolingType) {
+ case PoolingType::None:
+ return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
+ case PoolingType::MaxSigned:
+ return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MaxUnsigned:
+ return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MinSigned:
+ return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MinUnsigned:
+ return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::Sum:
+ return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
+ }
+ return false;
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Matchers for specific convolution operation.
+//===----------------------------------------------------------------------===//
+
+// #inputMap = affine_map<(W, w) -> (W + w)>
+// #filterMap = affine_map<(W, w) -> (w)>
+// #outputMap = affine_map<(W, w) -> (W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr W = m.dim(0);
+ AffineExpr w = m.dim(1);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
+// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
+// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DNwcWcfOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr F = m.dim(2);
+ AffineExpr w = m.dim(3);
+ AffineExpr c = m.dim(4);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+ /*filterMap=*/{w, c, F},
+ /*outputMap=*/{N, W, F}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
+// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
+// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DNcwFcwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr F = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr c = m.dim(3);
+ AffineExpr w = m.dim(4);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+ /*filterMap=*/{F, c, w},
+ /*outputMap=*/{N, F, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)>
+// #filterMap = affine_map<(H, W, h, w) -> (h, w)>
+// #outputMap = affine_map<(H, W, h, w) -> (H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ AffineExpr H = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr h = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{H, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
+// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
+// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv3DOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ AffineExpr D = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr d = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2)},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{D, H, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
+// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{C, w},
+ /*outputMap=*/{N, C, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C},
+ /*outputMap=*/{N, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
+// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr CM = m.dim(3);
+ AffineExpr w = m.dim(4);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C, CM},
+ /*outputMap=*/{N, W, C, CM}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{C, h, w},
+ /*outputMap=*/{N, C, H, W}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+// -> (N, D + d, H + h, W + w, C)>
+// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+// -> (d, h, w, C, CM)>
+// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+// -> (N, D, H, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr D = m.dim(1);
+ AffineExpr H = m.dim(2);
+ AffineExpr W = m.dim(3);
+ AffineExpr CM = m.dim(4);
+ AffineExpr d = m.dim(5);
+ AffineExpr h = m.dim(6);
+ AffineExpr w = m.dim(7);
+ AffineExpr C = m.dim(8);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w, C, CM},
+ /*outputMap=*/{N, D, H, W, C, CM}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MaxSigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MinSigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcSumOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::Sum);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MaxUnsigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
+ PoolingType::MinUnsigned);
+ AffineExpr N = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr C = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody();
+}
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 1382c7ac..d358362 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
+ MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 6ff63df..a1e3f10 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dfa2e4e..5404238 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//
-static bool isSupportedElementType(Type type) {
- return llvm::isa<MemRefType>(type) ||
- OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
MemRefType type = getType();
- if (!isSupportedElementType(type.getElementType()))
- return {};
if (!type.hasStaticShape())
return {};
// Make sure the memref contains only a single element.
@@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
OpBuilder &builder) {
- assert(isSupportedElementType(slot.elemType));
- // TODO: support more types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](MemRefType t) {
- return memref::AllocaOp::create(builder, getLoc(), t);
- })
- .Default([&](Type t) {
- return arith::ConstantOp::create(builder, getLoc(), t,
- builder.getZeroAttr(t));
- });
+ return ub::PoisonOp::create(builder, getLoc(), slot.elemType);
}
std::optional<PromotableAllocationOpInterface>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1c21a2f..1035d7c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return subview.getDynamicSize(sourceIndex);
}
- if (auto sizeInterface =
- dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
- assert(sizeInterface.isDynamicSize(unsignedIndex) &&
- "Expected dynamic subview size");
- return sizeInterface.getDynamicSize(unsignedIndex);
- }
-
// dim(memrefcast) -> dim
if (succeeded(foldMemRefCast(*this)))
return getResult();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index bd02516..c9352e8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
- if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
+    // ViewLikeOpInterface by itself doesn't guarantee that the base pointer is
+    // preserved in general (`memref.view` is one such example), so just check
+    // for a few specific cases.
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
+ !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 214410f..3667fdb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -347,28 +347,55 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
- llvm::TypeSwitch<Operation *, void>(loadOp)
+
+ return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
.Case([&](affine::AffineLoadOp op) {
rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices);
+ return success();
})
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
+ return success();
})
.Case([&](vector::LoadOp op) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
+ return success();
})
.Case([&](vector::MaskedLoadOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
+ return success();
+ })
+ .Case([&](vector::TransferReadOp op) {
+ // We only support minor identity maps in the permutation attribute.
+ if (!op.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ // We only support the case where the source of the expand shape has
+ // rank greater than or equal to the vector rank.
+ const int64_t sourceRank = sourceIndices.size();
+ const int64_t vectorRank = op.getVectorType().getRank();
+ if (sourceRank < vectorRank)
+ return failure();
+
+ // We need to construct a new minor identity map since we will have lost
+ // some dimensions in folding away the expand shape.
+ auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
+ op.getContext());
+
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ op, op.getVectorType(), expandShapeOp.getViewSource(),
+ sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+ op.getInBounds());
+ return success();
})
.DefaultUnreachable("unexpected operation");
- return success();
}
template <typename OpTy>
@@ -659,6 +686,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
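
The `vector.transfer_read` case above has to rebuild a minor identity permutation map because folding away the `memref.expand_shape` changes the rank of the source. A self-contained sketch (not part of the patch; `main` is purely illustrative) of what `AffineMap::getMinorIdentityMap` produces for a rank-4 source read into a rank-2 vector:

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

// Prints "(d0, d1, d2, d3) -> (d2, d3)": a map that reads along the innermost
// (minor) dimensions of the rank-4 source, which is the shape of map the
// folded transfer_read needs.
int main() {
  mlir::MLIRContext ctx;
  mlir::AffineMap map =
      mlir::AffineMap::getMinorIdentityMap(/*dims=*/4, /*results=*/2, &ctx);
  map.print(llvm::outs());
  llvm::outs() << "\n";
  return 0;
}
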
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15..c498c8a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
if (!dimIndex)
return failure();
- ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
- reifiedResultShapes)))
+ FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+ rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+ if (failed(replacement))
return failure();
- unsigned resultNumber = dimValue.getResultNumber();
- // Do not apply pattern if the IR is invalid (dim out of bounds).
- if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
- return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
- Value replacement = getValueOrCreateConstantIndexOp(
- rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
- rewriter.replaceOp(dimOp, replacement);
+ // Check if the OpFoldResult is empty (unreifiable dimension).
+ if (!replacement.value())
+ return failure();
+ Value replacementVal = getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), replacement.value());
+ rewriter.replaceOp(dimOp, replacementVal);
return success();
}
};
@@ -166,12 +165,14 @@ namespace {
struct ResolveRankedShapeTypeResultDimsPass final
: public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
ResolveRankedShapeTypeResultDimsPass> {
+ using Base::Base;
void runOnOperation() override;
};
struct ResolveShapedTypeResultDimsPass final
: public memref::impl::ResolveShapedTypeResultDimsPassBase<
ResolveShapedTypeResultDimsPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+ if (errorOnPatternIterationLimit && failed(result)) {
+ getOperation()->emitOpError(
+ "dim operation resolution hit pattern iteration limit");
return signalPassFailure();
+ }
}
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+ if (errorOnPatternIterationLimit && failed(result)) {
+ getOperation()->emitOpError(
+ "dim operation resolution hit pattern iteration limit");
return signalPassFailure();
+ }
}
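
`reifyDimOfResult` returns a `FailureOr<OpFoldResult>`, which distinguishes a failed reification from one that succeeded but produced no usable result (a null `OpFoldResult`). A short sketch of the consumption pattern, using only utilities visible in the hunk above (`materializeReifiedDim` is a hypothetical helper):

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: turns a reified dimension into an SSA value, or returns
// a null Value when the caller should keep the original memref.dim/tensor.dim.
static Value materializeReifiedDim(RewriterBase &rewriter, Location loc,
                                   FailureOr<OpFoldResult> reified) {
  if (failed(reified) || !*reified)
    return Value();
  // Reuses an existing Value or creates an arith.constant of index type.
  return getValueOrCreateConstantIndexOp(rewriter, loc, *reified);
}
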
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 14152c5..e5cc41e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -268,61 +268,82 @@ struct SubViewOpInterface
MemRefType sourceType = subView.getSource().getType();
// For each dimension, assert that:
- // 0 <= offset < dim_size
- // 0 <= offset + (size - 1) * stride < dim_size
+ // For empty slices (size == 0) : 0 <= offset <= dim_size
+ // For non-empty slices (size > 0): 0 <= offset < dim_size
+ // 0 <= offset + (size - 1) * stride
+ // < dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
+
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
- // Reset insertion point to before the operation for each dimension
+ // Reset insertion point to before the operation for each dimension.
builder.setInsertionPoint(subView);
+
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
subView.getMixedSizes()[i]);
Value stride = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedStrides()[i]);
-
- // Verify that offset is in-bounds.
Value dimSize = metadataOp.getSizes()[i];
- Value offsetInBounds =
- generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(builder, loc, offsetInBounds,
+
+ // Verify that offset is in-bounds (conditional on slice size).
+ Value sizeIsZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, size, zero);
+ auto offsetCheckIf = scf::IfOp::create(
+ builder, loc, sizeIsZero,
+ [&](OpBuilder &b, Location loc) {
+ // For empty slices, offset can be at the boundary: 0 <= offset <=
+ // dimSize.
+ Value offsetGEZero = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sge, offset, zero);
+ Value offsetLEDimSize = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sle, offset, dimSize);
+ Value emptyOffsetValid =
+ arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
+ scf::YieldOp::create(b, loc, emptyOffsetValid);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // For non-empty slices, offset must be a valid index: 0 <= offset
+ // < dimSize.
+ Value offsetInBounds =
+ generateInBoundsCheck(b, loc, offset, zero, dimSize);
+ scf::YieldOp::create(b, loc, offsetInBounds);
+ });
+
+ Value offsetCondition = offsetCheckIf.getResult(0);
+ cf::AssertOp::create(builder, loc, offsetCondition,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));
- // Only verify if size > 0
+ // Verify that the slice endpoint is in-bounds (only for non-empty
+ // slices).
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);
+ auto ifOp = scf::IfOp::create(
+ builder, loc, sizeIsNonZero,
+ [&](OpBuilder &b, Location loc) {
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
+ Value sizeMinusOneTimesStride =
+ arith::MulIOp::create(b, loc, sizeMinusOne, stride);
+ Value lastPos =
+ arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds =
+ generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(b, loc, lastPosInBounds);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value trueVal =
+ arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
+ scf::YieldOp::create(b, loc, trueVal);
+ });
- auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
- sizeIsNonZero, /*withElseRegion=*/true);
-
- // Populate the "then" region (for size > 0).
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
- Value sizeMinusOneTimesStride =
- arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
- Value lastPos =
- arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
- Value lastPosInBounds =
- generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
-
- scf::YieldOp::create(builder, loc, lastPosInBounds);
-
- // Populate the "else" region (for size == 0).
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- Value trueVal =
- arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
- scf::YieldOp::create(builder, loc, trueVal);
-
- builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);
-
cf::AssertOp::create(
builder, loc, finalCondition,
generateErrorMessage(op,
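
The runtime-verification rewrite above replaces manual region population with the callback-style `scf.if` builder, which creates both regions and derives the result type from the yielded values. A minimal, stand-alone sketch of that builder form (`selectI1` is a hypothetical helper, not from the patch):

#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Hypothetical helper building
//   %r = scf.if %cond -> (i1) { scf.yield %a } else { scf.yield %b }
// with the region-builder callbacks; the result type is taken from the
// yielded values, so no explicit result-type list is needed.
static Value selectI1(OpBuilder &builder, Location loc, Value cond, Value a,
                      Value b) {
  auto ifOp = scf::IfOp::create(
      builder, loc, cond,
      [&](OpBuilder &thenBuilder, Location thenLoc) {
        scf::YieldOp::create(thenBuilder, thenLoc, a);
      },
      [&](OpBuilder &elseBuilder, Location elseLoc) {
        scf::YieldOp::create(elseBuilder, elseLoc, b);
      });
  return ifOp.getResult(0);
}
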
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 6200366..e548698 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -133,17 +133,20 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
}
/// Returns true if all the uses of op are not read/load.
-/// There can be SubviewOp users as long as all its users are also
+/// There can be view-like-op users as long as all of their users are also
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
/// returns false uses is unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
std::vector<Operation *> opUses;
for (OpOperand &use : op->getUses()) {
Operation *useOp = use.getOwner();
+ // The use is a terminator (e.g. return or yield), so the value escapes the
+ // scope; conservatively treat it as read.
+ if (useOp->mightHaveTrait<OpTrait::IsTerminator>())
+ return false;
if (isa<memref::DeallocOp>(useOp) ||
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
- (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
+ (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
opUses.push_back(useOp);
continue;
}
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 2a857ed..0d05313 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
- auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
+ auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);
Type elementType = getElementTypeOrSelf(memref.getType());
auto vt = VectorType::get(vectorShape, elementType);
@@ -727,7 +727,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
toStore.push_back(v);
});
- return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
+ return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
@@ -792,7 +792,7 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
if (failed(maybeInfo))
return failure();
- MmaSyncInfo info = *maybeInfo;
+ const MmaSyncInfo &info = *maybeInfo;
auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
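
The two `std::move` removals above keep a callable parameter usable after the call, and binding `info` as a `const &` avoids a needless copy. A small, self-contained C++ illustration of the moved-from hazard, independent of the MLIR APIs (`applyTwice` is a hypothetical helper):

#include <functional>
#include <iostream>

// Hypothetical helper: it only invokes the callable, so a const reference is
// enough; the caller's std::function is neither copied nor moved from.
static int applyTwice(const std::function<int(int)> &fn, int x) {
  return fn(fn(x));
}

int main() {
  std::function<int(int)> addOne = [](int v) { return v + 1; };

  std::cout << applyTwice(addOne, 3) << "\n"; // prints 5
  // addOne is still valid and reusable here.
  std::cout << addOne(10) << "\n"; // prints 11

  // Had applyTwice taken the callable by value and been called with
  // applyTwice(std::move(addOne), 3), addOne would be left in a valid but
  // unspecified (typically empty) state, and calling it again could throw
  // std::bad_function_call.
  return 0;
}
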
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
index 40e769e..1d775fb 100644
--- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) {
return mlir::emitError(loc, "not yet implemented: " + message);
}
+bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
+ Operation **definingOpPtr) {
+ if (impl)
+ return impl->isValidSymbolUse(user, symbol, definingOpPtr);
+ return acc::isValidSymbolUse(user, symbol, definingOpPtr);
+}
+
} // namespace acc
} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 8c9c137..47f1222 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
@@ -203,12 +204,91 @@ struct MemRefPointerLikeModel
return false;
}
+
+ mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> srcPtr,
+ Type valueType) const {
+ // Load from a memref - only valid for scalar memrefs (rank 0).
+ // This is because the address computation for memrefs is part of the load
+ // (and not computed separately), but the API does not have arguments for
+ // indexing.
+ auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
+ if (!memrefValue)
+ return {};
+
+ auto memrefTy = memrefValue.getType();
+
+ // Only load from scalar memrefs (rank 0)
+ if (memrefTy.getRank() != 0)
+ return {};
+
+ return memref::LoadOp::create(builder, loc, memrefValue);
+ }
+
+ bool genStore(Type pointer, OpBuilder &builder, Location loc,
+ Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+ // Store to a memref - only valid for scalar memrefs (rank 0)
+ // This is because the address computation for memrefs is part of the store
+ // (and not computed separately), but the API does not have arguments for
+ // indexing.
+ auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
+ if (!memrefValue)
+ return false;
+
+ auto memrefTy = memrefValue.getType();
+
+ // Only store to scalar memrefs (rank 0)
+ if (memrefTy.getRank() != 0)
+ return false;
+
+ memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
+ return true;
+ }
};
struct LLVMPointerPointerLikeModel
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
LLVM::LLVMPointerType> {
Type getElementType(Type pointer) const { return Type(); }
+
+ mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> srcPtr,
+ Type valueType) const {
+ // For LLVM pointers, we need the valueType to determine what to load
+ if (!valueType)
+ return {};
+
+ return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
+ }
+
+ bool genStore(Type pointer, OpBuilder &builder, Location loc,
+ Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
+ LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
+ return true;
+ }
+};
+
+struct MemrefAddressOfGlobalModel
+ : public AddressOfGlobalOpInterface::ExternalModel<
+ MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
+ SymbolRefAttr getSymbol(Operation *op) const {
+ auto getGlobalOp = cast<memref::GetGlobalOp>(op);
+ return getGlobalOp.getNameAttr();
+ }
+};
+
+struct MemrefGlobalVariableModel
+ : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
+ memref::GlobalOp> {
+ bool isConstant(Operation *op) const {
+ auto globalOp = cast<memref::GlobalOp>(op);
+ return globalOp.getConstant();
+ }
+
+ Region *getInitRegion(Operation *op) const {
+ // GlobalOp uses attributes for initialization, not regions
+ return nullptr;
+ }
};
/// Helper function for any of the times we need to modify an ArrayAttr based on
@@ -302,6 +382,11 @@ void OpenACCDialect::initialize() {
MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
+
+ // Attach operation interfaces
+ memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
+ *getContext());
+ memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
}
//===----------------------------------------------------------------------===//
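
The `genLoad`/`genStore` hooks above extend the acc `PointerLikeType` models; the memref model restricts itself to rank-0 memrefs because these ops carry no index operands. A hedged sketch of a caller, assuming the generated interface methods mirror the model signatures minus the leading `Type` parameter (the usual MLIR type-interface convention); `loadThroughPointerLike` is hypothetical:

#include "mlir/Dialect/OpenACC/OpenACC.h"

using namespace mlir;

// Hypothetical utility: materialize a scalar load through whatever
// pointer-like type `ptr` has (rank-0 memref, llvm.ptr, ...). A null Value is
// returned when the type's model declines, e.g. a ranked memref whose address
// computation would need indices.
static Value loadThroughPointerLike(OpBuilder &builder, Location loc,
                                    TypedValue<acc::PointerLikeType> ptr,
                                    Type valueType) {
  acc::PointerLikeType ptrTy = ptr.getType();
  return ptrTy.genLoad(builder, loc, ptr, valueType);
}
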
@@ -467,6 +552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
return success();
}
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
+ // Mappable types do not need a recipe because it is possible to generate one
+ // from its API. Reject reductions though because no API is available for them
+ // Mappable types do not need a recipe because one can be generated from the
+ // mappable type's API. Reject reductions though, because no such API is
+ // available for them at this time.
+ return success();
+
+ mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
+ if (!operandRecipe)
+ return op->emitOpError() << "recipe expected for " << operandName;
+
+ auto decl =
+ SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe);
+ if (!decl)
+ return op->emitOpError()
+ << "expected symbol reference " << operandRecipe << " to point to a "
+ << operandName << " declaration";
+ return success();
+}
+
static ParseResult parseVar(mlir::OpAsmParser &parser,
OpAsmParser::UnresolvedOperand &var) {
// Either `var` or `varPtr` keyword is required.
@@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
}
}
+static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
+ mlir::SymbolRefAttr &recipeAttr) {
+ if (failed(parser.parseAttribute(recipeAttr)))
+ return failure();
+ return success();
+}
+
+static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::SymbolRefAttr recipeAttr) {
+ p << recipeAttr;
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
@@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(
+ checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private")))
+ return failure();
return success();
}
@@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>(
+ *this, "firstprivate")))
+ return failure();
return success();
}
@@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() {
return failure();
if (failed(checkNoModifier(*this)))
return failure();
+ if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>(
+ *this, "reduction")))
+ return failure();
return success();
}
@@ -1322,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
return recipe;
}
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName,
+ FirstprivateRecipeOp firstprivRecipe) {
+ // Create the private.recipe op with the same type as the firstprivate.recipe.
+ OpBuilder::InsertionGuard guard(builder);
+ auto varType = firstprivRecipe.getType();
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Clone the init region
+ IRMapping mapping;
+ firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
+
+ // Clone destroy region if the firstprivate.recipe has one.
+ if (!firstprivRecipe.getDestroyRegion().empty()) {
+ IRMapping mapping;
+ firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
+ mapping);
+ }
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1432,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
}
//===----------------------------------------------------------------------===//
-// Custom parser and printer verifier for private clause
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseSymOperandList(
- mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
- llvm::SmallVector<SymbolRefAttr> attributes;
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseArrow() ||
- parser.parseOperand(operands.emplace_back()) ||
- parser.parseColonType(types.emplace_back()))
- return failure();
- return success();
- })))
- return failure();
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
- return success();
-}
-
-static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange operands,
- mlir::TypeRange types,
- std::optional<mlir::ArrayAttr> attributes) {
- llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
- p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
- << std::get<1>(it).getType();
- });
-}
-
-//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
@@ -1484,45 +1600,19 @@ static LogicalResult checkDataOperands(Op op,
return success();
}
-template <typename Op>
-static LogicalResult
-checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
- mlir::OperandRange operands, llvm::StringRef operandName,
- llvm::StringRef symbolName, bool checkOperandType = true) {
- if (!operands.empty()) {
- if (!attributes || attributes->size() != operands.size())
- return op->emitOpError()
- << "expected as many " << symbolName << " symbol reference as "
- << operandName << " operands";
- } else {
- if (attributes)
- return op->emitOpError()
- << "unexpected " << symbolName << " symbol reference";
- return success();
- }
-
+template <typename OpT, typename RecipeOpT>
+static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
+ const mlir::ValueRange &operands,
+ llvm::StringRef operandName) {
llvm::DenseSet<Value> set;
- for (auto args : llvm::zip(operands, *attributes)) {
- mlir::Value operand = std::get<0>(args);
-
+ for (mlir::Value operand : operands) {
+ if (!mlir::isa<OpT>(operand.getDefiningOp()))
+ return accConstructOp->emitOpError()
+ << "expected " << operandName << " as defining op";
if (!set.insert(operand).second)
- return op->emitOpError()
+ return accConstructOp->emitOpError()
<< operandName << " operand appears more than once";
-
- mlir::Type varType = operand.getType();
- auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
- auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
- if (!decl)
- return op->emitOpError()
- << "expected symbol reference " << symbolRef << " to point to a "
- << operandName << " declaration";
-
- if (checkOperandType && decl.getType() && decl.getType() != varType)
- return op->emitOpError() << "expected " << operandName << " (" << varType
- << ") to be the same type as " << operandName
- << " declaration (" << decl.getType() << ")";
}
-
return success();
}
@@ -1579,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
}
LogicalResult acc::ParallelOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -1720,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
mlir::ValueRange gangPrivateOperands,
mlir::ValueRange gangFirstPrivateOperands,
mlir::ValueRange dataClauseOperands) {
-
ParallelOp::build(
odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
@@ -1729,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder,
/*numGangsDeviceType=*/nullptr, numWorkers,
/*numWorkersDeviceType=*/nullptr, vectorLength,
/*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
- /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
- gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
- /*firstprivatizations=*/nullptr, dataClauseOperands,
+ /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
+ gangFirstPrivateOperands, dataClauseOperands,
/*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
@@ -1808,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands(
void acc::ParallelOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::ParallelOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
static ParseResult parseNumGangs(
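
With this change the recipe symbol lives on the data-entry op itself instead of a parallel `privatizationRecipes` array on the construct, so callers no longer keep two lists in sync. A sketch of the intended usage built only from helpers shown in this hunk (`attachPrivatization` itself is hypothetical):

#include "mlir/Dialect/OpenACC/OpenACC.h"

using namespace mlir;

// Hypothetical frontend helper: associates an already-created acc.private
// data-entry op and its recipe with a compute construct. addPrivatization now
// stamps the recipe symbol onto the acc.private op and appends its result to
// the construct's private operand list.
static void attachPrivatization(MLIRContext *ctx, acc::ParallelOp construct,
                                acc::PrivateOp dataEntry,
                                acc::PrivateRecipeOp recipe) {
  construct.addPrivatization(ctx, dataEntry, recipe);
  // Equivalent to what the helper does internally:
  //   dataEntry.setRecipeAttr(SymbolRefAttr::get(ctx, recipe.getSymName()));
  //   construct.getPrivateOperandsMutable().append(dataEntry.getResult());
}
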
@@ -2415,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
}
LogicalResult acc::SerialOp::verify() {
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (failed(verifyDeviceTypeAndSegmentCountMatch(
@@ -2489,46 +2553,22 @@ void acc::SerialOp::addWaitOperands(
void acc::SerialOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::SerialOp::addReduction(MLIRContext *context,
mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -2658,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() {
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
+void acc::KernelsOp::addPrivatization(MLIRContext *context,
+ mlir::acc::PrivateOp op,
+ mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
+ getPrivateOperandsMutable().append(op.getResult());
+}
+
+void acc::KernelsOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
+ getFirstprivateOperandsMutable().append(op.getResult());
+}
+
+void acc::KernelsOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
+ getReductionOperandsMutable().append(op.getResult());
+}
+
void acc::KernelsOp::addNumWorkersOperand(
MLIRContext *context, mlir::Value newValue,
llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
@@ -2967,19 +3028,21 @@ bool hasDuplicateDeviceTypes(
}
/// Check for duplicates in the DeviceType array attribute.
-LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
+/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
+static std::optional<mlir::acc::DeviceType>
+checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
if (!deviceTypes)
- return success();
+ return std::nullopt;
for (auto attr : deviceTypes) {
auto deviceTypeAttr =
mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
if (!deviceTypeAttr)
- return failure();
+ return mlir::acc::DeviceType::None;
if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
- return failure();
+ return deviceTypeAttr.getValue();
}
- return success();
+ return std::nullopt;
}
LogicalResult acc::LoopOp::verify() {
@@ -3006,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() {
getCollapseDeviceTypeAttr().getValue().size())
return emitOpError() << "collapse attribute count must match collapse"
<< " device_type count";
- if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
- return emitOpError()
- << "duplicate device_type found in collapseDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in collapseDeviceType attribute";
// Check gang
if (!getGangOperands().empty()) {
@@ -3021,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() {
return emitOpError() << "gangOperandsArgType attribute count must match"
<< " gangOperands count";
}
- if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
- return emitOpError() << "duplicate device_type found in gang attribute";
+ if (getGangAttr()) {
+ if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in gang attribute";
+ }
if (failed(verifyDeviceTypeAndSegmentCountMatch(
*this, getGangOperands(), getGangOperandsSegmentsAttr(),
@@ -3030,22 +3098,30 @@ LogicalResult acc::LoopOp::verify() {
return failure();
// Check worker
- if (failed(checkDeviceTypes(getWorkerAttr())))
- return emitOpError() << "duplicate device_type found in worker attribute";
- if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "workerNumOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in worker attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in workerNumOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
getWorkerNumOperandsDeviceTypeAttr(),
"worker")))
return failure();
// Check vector
- if (failed(checkDeviceTypes(getVectorAttr())))
- return emitOpError() << "duplicate device_type found in vector attribute";
- if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
- return emitOpError() << "duplicate device_type found in "
- "vectorOperandsDeviceType attribute";
+ if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vector attribute";
+ if (auto duplicateDeviceType =
+ checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
+ return emitOpError() << "duplicate device_type `"
+ << acc::stringifyDeviceType(*duplicateDeviceType)
+ << "` found in vectorOperandsDeviceType attribute";
if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
getVectorOperandsDeviceTypeAttr(),
"vector")))
@@ -3110,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() {
}
}
- if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
- *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
- "privatizations", false)))
+ if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
+ mlir::acc::PrivateRecipeOp>(
+ *this, getPrivateOperands(), "private")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
- *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
- "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+ if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
+ mlir::acc::FirstprivateRecipeOp>(
+ *this, getFirstprivateOperands(), "firstprivate")))
return failure();
- if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
- *this, getReductionRecipes(), getReductionOperands(), "reduction",
- "reductions", false)))
+ if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
+ mlir::acc::ReductionRecipeOp>(
+ *this, getReductionOperands(), "reduction")))
return failure();
if (getCombined().has_value() &&
@@ -3556,45 +3632,21 @@ void acc::LoopOp::addGangOperands(
void acc::LoopOp::addPrivatization(MLIRContext *context,
mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getPrivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getPrivatizationRecipesAttr())
- llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addFirstPrivatization(
MLIRContext *context, mlir::acc::FirstprivateOp op,
mlir::acc::FirstprivateRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getFirstprivateOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getFirstprivatizationRecipesAttr())
- llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
mlir::acc::ReductionRecipeOp recipe) {
+ op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
getReductionOperandsMutable().append(op.getResult());
-
- llvm::SmallVector<mlir::Attribute> recipes;
-
- if (getReductionRecipesAttr())
- llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
-
- recipes.push_back(
- mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
- setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
//===----------------------------------------------------------------------===//
@@ -4059,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() {
if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
- "be present at the same time";
+ "be present at the same time for device_type `"
+ << acc::stringifyDeviceType(dtype) << "`";
}
return success();
@@ -4356,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
return std::nullopt;
}
+void RoutineOp::addSeq(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addVector(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addWorker(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addGang(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
+ setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
+ effectiveDeviceTypes));
+}
+
+void RoutineOp::addGang(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+ uint64_t val) {
+ llvm::SmallVector<mlir::Attribute> dimValues;
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
+
+ if (getGangDimAttr())
+ llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
+ if (getGangDimDeviceTypeAttr())
+ llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
+
+ assert(dimValues.size() == deviceTypes.size());
+
+ if (effectiveDeviceTypes.empty()) {
+ dimValues.push_back(
+ mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
+ deviceTypes.push_back(
+ acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
+ } else {
+ for (DeviceType dt : effectiveDeviceTypes) {
+ dimValues.push_back(
+ mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
+ deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
+ }
+ }
+ assert(dimValues.size() == deviceTypes.size());
+
+ setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
+ setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
+}
+
+void RoutineOp::addBindStrName(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+ mlir::StringAttr val) {
+ unsigned before = getBindStrNameDeviceTypeAttr()
+ ? getBindStrNameDeviceTypeAttr().size()
+ : 0;
+
+ setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+ context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
+ unsigned after = getBindStrNameDeviceTypeAttr().size();
+
+ llvm::SmallVector<mlir::Attribute> vals;
+ if (getBindStrNameAttr())
+ llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
+ for (unsigned i = 0; i < after - before; ++i)
+ vals.push_back(val);
+
+ setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
+}
+
+void RoutineOp::addBindIDName(MLIRContext *context,
+ llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
+ mlir::SymbolRefAttr val) {
+ unsigned before =
+ getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
+
+ setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
+ context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
+ unsigned after = getBindIdNameDeviceTypeAttr().size();
+
+ llvm::SmallVector<mlir::Attribute> vals;
+ if (getBindIdNameAttr())
+ llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
+ for (unsigned i = 0; i < after - before; ++i)
+ vals.push_back(val);
+
+ setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
+}
+
//===----------------------------------------------------------------------===//
// InitOp
//===----------------------------------------------------------------------===//
@@ -4739,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
.Default([&](mlir::Operation *) { return nullptr; })};
return dataOperands;
}
+
+mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
+ auto recipe{
+ llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp)
+ .Case<ACC_DATA_ENTRY_OPS>(
+ [&](auto entry) { return entry.getRecipeAttr(); })
+ .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
+ return recipe;
+}
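
The new `acc::getRecipe` accessor gives analyses a uniform way to read the per-operand recipe symbol from any data-entry op. A hedged sketch of a consumer that resolves it to the recipe operation (`dumpPrivateRecipes` is hypothetical; the lookup helpers are the ones already used in this file):

#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical consumer: walks the private operands of a construct and
// resolves the recipe each data-entry op carries, if any.
static void dumpPrivateRecipes(acc::ParallelOp construct) {
  for (Value operand : construct.getPrivateOperands()) {
    Operation *dataEntry = operand.getDefiningOp();
    if (!dataEntry)
      continue;
    SymbolRefAttr recipeSym = acc::getRecipe(dataEntry);
    if (!recipeSym)
      continue; // e.g. a mappable type, for which no recipe is required.
    auto recipe = SymbolTable::lookupNearestSymbolFrom<acc::PrivateRecipeOp>(
        dataEntry, recipeSym);
    if (recipe)
      llvm::errs() << "resolved recipe " << recipeSym << "\n";
  }
}
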
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index 91262bd..67cdf10 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -237,11 +237,6 @@ public:
void runOnOperation() override;
private:
- /// Collects all data clauses that dominate the compute construct.
- /// Needed to determine if a variable is already covered by an existing data
- /// clause.
- SmallVector<Value> getDominatingDataClauses(Operation *computeConstructOp);
-
/// Looks through the `dominatingDataClauses` to find the original data clause
/// op for an alias. Returns nullptr if no original data clause op is found.
template <typename OpT>
@@ -277,8 +272,7 @@ private:
/// Generates recipes for a list of variables.
void generateRecipes(ModuleOp &module, OpBuilder &builder,
Operation *computeConstructOp,
- const SmallVector<Value> &newOperands,
- SmallVector<Attribute> &newRecipeSyms);
+ const SmallVector<Value> &newOperands);
};
/// Determines if a variable is a candidate for implicit data mapping.
@@ -301,62 +295,6 @@ static bool isCandidateForImplicitData(Value val, Region &accRegion) {
return true;
}
-SmallVector<Value>
-ACCImplicitData::getDominatingDataClauses(Operation *computeConstructOp) {
- llvm::SmallSetVector<Value, 8> dominatingDataClauses;
-
- llvm::TypeSwitch<Operation *>(computeConstructOp)
- .Case<acc::ParallelOp, acc::KernelsOp, acc::SerialOp>([&](auto op) {
- for (auto dataClause : op.getDataClauseOperands()) {
- dominatingDataClauses.insert(dataClause);
- }
- })
- .Default([](Operation *) {});
-
- // Collect the data clauses from enclosing data constructs.
- Operation *currParentOp = computeConstructOp->getParentOp();
- while (currParentOp) {
- if (isa<acc::DataOp>(currParentOp)) {
- for (auto dataClause :
- dyn_cast<acc::DataOp>(currParentOp).getDataClauseOperands()) {
- dominatingDataClauses.insert(dataClause);
- }
- }
- currParentOp = currParentOp->getParentOp();
- }
-
- // Find the enclosing function/subroutine
- auto funcOp = computeConstructOp->getParentOfType<FunctionOpInterface>();
- if (!funcOp)
- return dominatingDataClauses.takeVector();
-
- // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that
- // dominate and post-dominate the compute construct and add their data
- // clauses to the list.
- auto &domInfo = this->getAnalysis<DominanceInfo>();
- auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
- funcOp->walk([&](acc::DeclareEnterOp declareEnterOp) {
- if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) {
- // Collect all `acc.declare_exit` ops for this token.
- SmallVector<acc::DeclareExitOp> exits;
- for (auto *user : declareEnterOp.getToken().getUsers())
- if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
- exits.push_back(declareExit);
-
- // Only add clauses if every `acc.declare_exit` op post-dominates the
- // compute construct.
- if (!exits.empty() && llvm::all_of(exits, [&](acc::DeclareExitOp exitOp) {
- return postDomInfo.postDominates(exitOp, computeConstructOp);
- })) {
- for (auto dataClause : declareEnterOp.getDataClauseOperands())
- dominatingDataClauses.insert(dataClause);
- }
- }
- });
-
- return dominatingDataClauses.takeVector();
-}
-
template <typename OpT>
Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
Value var, OpBuilder &builder, OpT computeConstructOp,
@@ -453,23 +391,23 @@ ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var,
void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder,
Operation *computeConstructOp,
- const SmallVector<Value> &newOperands,
- SmallVector<Attribute> &newRecipeSyms) {
+ const SmallVector<Value> &newOperands) {
auto &accSupport = this->getAnalysis<acc::OpenACCSupport>();
for (auto var : newOperands) {
auto loc{var.getLoc()};
- if (isa<acc::PrivateOp>(var.getDefiningOp())) {
+ if (auto privateOp = dyn_cast<acc::PrivateOp>(var.getDefiningOp())) {
auto recipe = generatePrivateRecipe(
module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport);
if (recipe)
- newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(),
- recipe.getSymName().str()));
- } else if (isa<acc::FirstprivateOp>(var.getDefiningOp())) {
+ privateOp.setRecipeAttr(
+ SymbolRefAttr::get(module->getContext(), recipe.getSymName()));
+ } else if (auto firstprivateOp =
+ dyn_cast<acc::FirstprivateOp>(var.getDefiningOp())) {
auto recipe = generateFirstprivateRecipe(
module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport);
if (recipe)
- newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(),
- recipe.getSymName().str()));
+ firstprivateOp.setRecipeAttr(SymbolRefAttr::get(
+ module->getContext(), recipe.getSymName().str()));
} else {
accSupport.emitNYI(var.getLoc(), "implicit reduction");
}
@@ -570,6 +508,8 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
newDataOp = acc::PresentOp::create(builder, loc, var,
/*structured=*/true, /*implicit=*/true,
accSupport.getVariableName(var));
+ newDataOp->setAttr(acc::getFromDefaultClauseAttrName(),
+ builder.getUnitAttr());
} else {
auto copyinOp =
acc::CopyinOp::create(builder, loc, var,
@@ -611,56 +551,22 @@ static void legalizeValuesInRegion(Region &accRegion,
}
}
-// Adds the private operands and private recipes to the data construct
-// operation in a valid way (ensures that the index in the privatizationRecipes
-// array matches the position of the private operand).
+// Adds the private operands to the compute construct operation.
template <typename OpT>
-static void
-addNewPrivateOperands(OpT &accOp, const SmallVector<Value> &privateOperands,
- const SmallVector<Attribute> &privateRecipeSyms) {
- assert(privateOperands.size() == privateRecipeSyms.size());
+static void addNewPrivateOperands(OpT &accOp,
+ const SmallVector<Value> &privateOperands) {
if (privateOperands.empty())
return;
- SmallVector<Attribute> completePrivateRecipesSyms;
- SmallVector<Attribute> completeFirstprivateRecipesSyms;
- SmallVector<Value> newPrivateOperands;
- SmallVector<Value> newFirstprivateOperands;
-
- // Collect all of the existing recipes since they are held in an attribute.
- // To add to it, we need to create a brand new one.
- if (accOp.getPrivatizationRecipes().has_value())
- for (auto privatization : accOp.getPrivatizationRecipesAttr())
- completePrivateRecipesSyms.push_back(privatization);
- if (accOp.getFirstprivatizationRecipes().has_value())
- for (auto privatization : accOp.getFirstprivatizationRecipesAttr())
- completeFirstprivateRecipesSyms.push_back(privatization);
-
- // Now separate between private and firstprivate operands.
- for (auto [priv, privateRecipeSym] :
- llvm::zip(privateOperands, privateRecipeSyms)) {
+ for (auto priv : privateOperands) {
if (isa<acc::PrivateOp>(priv.getDefiningOp())) {
- newPrivateOperands.push_back(priv);
- completePrivateRecipesSyms.push_back(privateRecipeSym);
+ accOp.getPrivateOperandsMutable().append(priv);
} else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) {
- newFirstprivateOperands.push_back(priv);
- completeFirstprivateRecipesSyms.push_back(privateRecipeSym);
+ accOp.getFirstprivateOperandsMutable().append(priv);
} else {
- llvm_unreachable("unhandled private operand");
+ llvm_unreachable("unhandled reduction operand");
}
}
-
- // Append all of the new private operands to their appropriate list.
- accOp.getPrivateOperandsMutable().append(newPrivateOperands);
- accOp.getFirstprivateOperandsMutable().append(newFirstprivateOperands);
-
- // Update the privatizationRecipes attributes to hold all of the new recipes.
- if (!completePrivateRecipesSyms.empty())
- accOp.setPrivatizationRecipesAttr(
- ArrayAttr::get(accOp.getContext(), completePrivateRecipesSyms));
- if (!completeFirstprivateRecipesSyms.empty())
- accOp.setFirstprivatizationRecipesAttr(
- ArrayAttr::get(accOp.getContext(), completeFirstprivateRecipesSyms));
}
static Operation *findDataExitOp(Operation *dataEntryOp) {
@@ -808,7 +714,10 @@ void ACCImplicitData::generateImplicitDataOps(
LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n"
<< computeConstructOp << "\n");
}
- auto dominatingDataClauses = getDominatingDataClauses(computeConstructOp);
+ auto &domInfo = this->getAnalysis<DominanceInfo>();
+ auto &postDomInfo = this->getAnalysis<PostDominanceInfo>();
+ auto dominatingDataClauses =
+ acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo);
for (auto var : candidateVars) {
auto newDataClauseOp = generateDataClauseOpForCandidate(
var, module, builder, computeConstructOp, dominatingDataClauses,
@@ -829,13 +738,11 @@ void ACCImplicitData::generateImplicitDataOps(
// of the data clause ops)
legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands);
- SmallVector<Attribute> newPrivateRecipeSyms;
// 5) Generate private recipes which are required for properly attaching
// private operands.
if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
!std::is_same_v<OpT, acc::KernelEnvironmentOp>)
- generateRecipes(module, builder, computeConstructOp, newPrivateOperands,
- newPrivateRecipeSyms);
+ generateRecipes(module, builder, computeConstructOp, newPrivateOperands);
// 6) Figure out insertion order for the new data clause operands.
SmallVector<Value> sortedDataClauseOperands(
@@ -846,15 +753,10 @@ void ACCImplicitData::generateImplicitDataOps(
// 7) Generate the data exit operations.
generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands,
sortedDataClauseOperands);
-
// 8) Add all of the new operands to the compute construct op.
- assert(newPrivateOperands.size() == newPrivateRecipeSyms.size() &&
- "sizes must match");
if constexpr (!std::is_same_v<OpT, acc::KernelsOp> &&
!std::is_same_v<OpT, acc::KernelEnvironmentOp>)
- addNewPrivateOperands(computeConstructOp, newPrivateOperands,
- newPrivateRecipeSyms);
-
+ addNewPrivateOperands(computeConstructOp, newPrivateOperands);
computeConstructOp.getDataClauseOperandsMutable().assign(
sortedDataClauseOperands);
}
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
new file mode 100644
index 0000000..8cab223
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp
@@ -0,0 +1,431 @@
+//===- ACCImplicitDeclare.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass applies implicit `acc declare` actions to global variables
+// referenced in OpenACC compute regions and routine functions.
+//
+// Overview:
+// ---------
+// Global references in acc regions (for globals not marked with `acc
+// declare` by the user) can be handled in one of two ways:
+// - Mapped through data clauses
+// - Implicitly marked as `acc declare` (this pass)
+//
+// Thus, the OpenACC specification focuses solely on implicit data mapping rules
+// whose implementation is captured in the `ACCImplicitData` pass.
+//
+// However, it is both advantageous and required for certain cases to
+// use implicit `acc declare` instead:
+// - Any functions that are implicitly marked as `acc routine` through
+// `ACCImplicitRoutine` may reference globals. Since data mapping
+// is only possible for compute regions, such globals can only be
+// made available on device through `acc declare`.
+// - The compiler can generate and use globals for cases needed in the IR
+//   representation, such as type descriptors or names used for runtime calls
+//   and error reporting. Such globals are often introduced after frontend
+//   semantic checking, since they are implementation details, and therefore
+//   were never visible for a user to mark with `acc declare`.
+// - Constant globals, such as filename strings or data initialization values,
+//   are never mutated but are still needed at runtime. If a kernel is launched
+//   1000 times, mapping such a global 1000 times is wasteful, so these globals
+//   benefit from being marked with `acc declare`.
+//
+// This pass automatically marks global variables with the `acc.declare`
+// attribute when they are referenced in OpenACC compute constructs or routine
+// functions and meet the criteria noted above, ensuring they are properly
+// handled for device execution.
+//
+// The pass performs two main optimizations:
+//
+// 1. Hoisting: For non-constant globals referenced in compute regions, the
+// pass hoists the address-of operation out of the region when possible,
+// allowing them to be implicitly mapped through normal data clause
+// mechanisms rather than requiring declare marking.
+//
+// 2. Declaration: For globals that must be available on the device (constants,
+// globals in routines, globals in recipe operations), the pass adds the
+// `acc.declare` attribute with the copyin data clause.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that compute addresses
+// of global variables must implement the `acc::AddressOfGlobalOpInterface`
+// and those that represent globals must implement the
+// `acc::GlobalOpInterface`. Additionally, any operations that indirectly
+// access globals must implement the `acc::IndirectGlobalAccessOpInterface`.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+// determining if a symbol use is valid within GPU regions, the dialect
+// should pre-register the `acc::OpenACCSupport` analysis.
+//
+// Examples:
+// ---------
+//
+// Example 1: Non-constant global in compute region (hoisted)
+//
+// Before:
+// memref.global @g_scalar : memref<f32> = dense<0.0>
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_scalar : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// memref.global @g_scalar : memref<f32> = dense<0.0>
+// func.func @test() {
+// %addr = memref.get_global @g_scalar : memref<f32>
+// acc.serial {
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 2: Constant global in compute region (declared)
+//
+// Before:
+// memref.global constant @g_const : memref<f32> = dense<1.0>
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_const : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// After:
+// memref.global constant @g_const : memref<f32> = dense<1.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// func.func @test() {
+// acc.serial {
+// %addr = memref.get_global @g_const : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// acc.yield
+// }
+// }
+//
+// Example 3: Global in acc routine (declared)
+//
+// Before:
+// memref.global @g_data : memref<f32> = dense<0.0>
+// acc.routine @routine_0 func(@device_func)
+// func.func @device_func() attributes {acc.routine_info = ...} {
+// %addr = memref.get_global @g_data : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// }
+//
+// After:
+// memref.global @g_data : memref<f32> = dense<0.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// acc.routine @routine_0 func(@device_func)
+// func.func @device_func() attributes {acc.routine_info = ...} {
+// %addr = memref.get_global @g_data : memref<f32>
+// %val = memref.load %addr[] : memref<f32>
+// }
+//
+// Example 4: Global in private recipe (declared if recipe is used)
+//
+// Before:
+// memref.global @g_init : memref<f32> = dense<0.0>
+// acc.private.recipe @priv_recipe : memref<f32> init {
+// ^bb0(%arg0: memref<f32>):
+// %alloc = memref.alloc() : memref<f32>
+// %global = memref.get_global @g_init : memref<f32>
+// %val = memref.load %global[] : memref<f32>
+// memref.store %val, %alloc[] : memref<f32>
+// acc.yield %alloc : memref<f32>
+// } destroy { ... }
+// func.func @test() {
+// %var = memref.alloc() : memref<f32>
+// %priv = acc.private varPtr(%var : memref<f32>)
+// recipe(@priv_recipe) -> memref<f32>
+// acc.parallel private(%priv : memref<f32>) { ... }
+// }
+//
+// After:
+// memref.global @g_init : memref<f32> = dense<0.0>
+// {acc.declare = #acc.declare<dataClause = acc_copyin>}
+// acc.private.recipe @priv_recipe : memref<f32> init {
+// ^bb0(%arg0: memref<f32>):
+// %alloc = memref.alloc() : memref<f32>
+// %global = memref.get_global @g_init : memref<f32>
+// %val = memref.load %global[] : memref<f32>
+// memref.store %val, %alloc[] : memref<f32>
+// acc.yield %alloc : memref<f32>
+// } destroy { ... }
+// func.func @test() {
+// %var = memref.alloc() : memref<f32>
+// %priv = acc.private varPtr(%var : memref<f32>)
+// recipe(@priv_recipe) -> memref<f32>
+// acc.parallel private(%priv : memref<f32>) { ... }
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITDECLARE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-implicit-declare"
+
+using namespace mlir;
+
+namespace {
+
+using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>;
+
+/// Checks whether a use of the requested `globalOp` should be considered
+/// for hoisting out of an acc region in order to avoid `acc declare`ing
+/// something that instead should be implicitly mapped.
+static bool isGlobalUseCandidateForHoisting(Operation *globalOp,
+ Operation *user,
+ SymbolRefAttr symbol,
+ acc::OpenACCSupport &accSupport) {
+  // This symbol is valid in a GPU region. This means the semantics would
+  // change if the use were moved to the host - therefore it is not a
+  // candidate.
+ if (accSupport.isValidSymbolUse(user, symbol))
+ return false;
+
+ bool isConstant = false;
+ bool isFunction = false;
+
+ if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp))
+ isConstant = globalVarOp.isConstant();
+
+ if (isa<FunctionOpInterface>(globalOp))
+ isFunction = true;
+
+  // Constants should be kept in device code to ensure they are duplicated.
+  // Function references should be kept in device code to ensure their device
+  // addresses are computed. Everything else should be hoisted since we already
+  // proved it is not a valid symbol use in a GPU region.
+ return !isConstant && !isFunction;
+}
+
+/// Checks whether it is valid to use acc.declare marking on the global.
+static bool isValidForAccDeclare(Operation *globalOp) {
+ // For functions - we use acc.routine marking instead.
+ return !isa<FunctionOpInterface>(globalOp);
+}
+
+/// Checks whether a recipe operation has a meaningful use of its symbol that
+/// justifies processing its regions for global references. Returns false if:
+/// 1. The recipe has no symbol uses at all, or
+/// 2. The only symbol use is the recipe's own symbol definition
+template <typename RecipeOpT>
+static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) {
+ std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod);
+
+ // No recipe symbol uses.
+ if (!symbolUses.has_value() || symbolUses->empty())
+ return false;
+
+ // If more than one use, assume it's used.
+ auto begin = symbolUses->begin();
+ auto end = symbolUses->end();
+ if (begin != end && std::next(begin) != end)
+ return true;
+
+ // If single use, check if the use is the recipe itself.
+ const SymbolTable::SymbolUse &use = *symbolUses->begin();
+ return use.getUser() != recipeOp.getOperation();
+}
+
+// Hoists addr_of operations for non-constant globals out of OpenACC regions.
+// This way, they are implicitly mapped instead of being considered for
+// implicit declare.
+template <typename AccConstructT>
+static void hoistNonConstantDirectUses(AccConstructT accOp,
+ acc::OpenACCSupport &accSupport) {
+ accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) {
+ SymbolRefAttr symRef = addrOfOp.getSymbol();
+ if (symRef) {
+ Operation *globalOp =
+ SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef);
+ if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef,
+ accSupport)) {
+ addrOfOp->moveBefore(accOp);
+ LLVM_DEBUG(
+ llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t";
+ accOp->print(llvm::dbgs(),
+ OpPrintingFlags{}.skipRegions().enableDebugInfo());
+ llvm::dbgs() << "\n");
+ }
+ }
+ });
+}
+
+// Collects the globals referenced in a device region
+static void collectGlobalsFromDeviceRegion(Region &region,
+ GlobalOpSetT &globals,
+ acc::OpenACCSupport &accSupport,
+ SymbolTable &symTab) {
+ region.walk([&](Operation *op) {
+ // 1) Only consider relevant operations which use symbols
+ auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op);
+ if (addrOfOp) {
+ SymbolRefAttr symRef = addrOfOp.getSymbol();
+      // 2) Found an operation which uses the symbol. Next, determine if it
+      // is a candidate for `acc declare`. One of the criteria considered
+      // is whether this symbol is not already a device one (either because
+      // acc declare is already used or this is a CUF global).
+ Operation *globalOp = nullptr;
+ bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp);
+ // 3) Add the candidate to the set of globals to be `acc declare`d.
+ if (isCandidate && globalOp && isValidForAccDeclare(globalOp))
+ globals.insert(globalOp);
+ } else if (auto indirectAccessOp =
+ dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) {
+ // Process operations that indirectly access globals
+ llvm::SmallVector<SymbolRefAttr> symbols;
+ indirectAccessOp.getReferencedSymbols(symbols, &symTab);
+ for (SymbolRefAttr symRef : symbols)
+ if (Operation *globalOp = symTab.lookup(symRef.getLeafReference()))
+ if (isValidForAccDeclare(globalOp))
+ globals.insert(globalOp);
+ }
+ });
+}
+
+// Adds the declare attribute to the operation `op`.
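+// For example, with the copyin clause the resulting attribute looks like:
+//   {acc.declare = #acc.declare<dataClause = acc_copyin>}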
+static void addDeclareAttr(MLIRContext *context, Operation *op,
+ acc::DataClause clause) {
+ op->setAttr(acc::getDeclareAttrName(),
+ acc::DeclareAttr::get(context,
+ acc::DataClauseAttr::get(context, clause)));
+}
+
+// This pass applies implicit declare actions for globals referenced in
+// OpenACC compute and routine regions.
+class ACCImplicitDeclare
+ : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> {
+public:
+ using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase;
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+ MLIRContext *context = &getContext();
+ acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
+
+    // 1) Start off by hoisting any AddressOf operations out of acc regions
+    // for any cases we do not want to `acc declare`. This is because we can
+    // rely on implicit data mapping in the majority of cases without
+    // needlessly polluting the set of device globals.
+ mod.walk([&](Operation *op) {
+ TypeSwitch<Operation *, void>(op)
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ [&](auto accOp) {
+ hoistNonConstantDirectUses(accOp, accSupport);
+ });
+ });
+
+ // 2) Collect global symbols which need to be `acc declare`d. Do it for
+ // compute regions, acc routine, and existing globals with the declare
+ // attribute.
+ SymbolTable symTab(mod);
+ GlobalOpSetT globalsToAccDeclare;
+ mod.walk([&](Operation *op) {
+ TypeSwitch<Operation *, void>(op)
+ .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>(
+ [&](auto accOp) {
+ collectGlobalsFromDeviceRegion(
+ accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
+ })
+ .Case<FunctionOpInterface>([&](auto func) {
+ if ((acc::isAccRoutine(func) ||
+ acc::isSpecializedAccRoutine(func)) &&
+ !func.isExternal())
+ collectGlobalsFromDeviceRegion(func.getFunctionBody(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ })
+ .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) {
+ if (globalVarOp->getAttr(acc::getDeclareAttrName()))
+ if (Region *initRegion = globalVarOp.getInitRegion())
+ collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare,
+ accSupport, symTab);
+ })
+ .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) {
+ if (hasRelevantRecipeUse(privateRecipe, mod)) {
+ collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ }
+ })
+ .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) {
+ if (hasRelevantRecipeUse(firstprivateRecipe, mod)) {
+ collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ collectGlobalsFromDeviceRegion(
+ firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare,
+ accSupport, symTab);
+ collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ }
+ })
+ .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) {
+ if (hasRelevantRecipeUse(reductionRecipe, mod)) {
+ collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(),
+ globalsToAccDeclare, accSupport,
+ symTab);
+ collectGlobalsFromDeviceRegion(
+ reductionRecipe.getCombinerRegion(), globalsToAccDeclare,
+ accSupport, symTab);
+ }
+ });
+ });
+
+    // 3) Finally, generate the appropriate declare actions needed to ensure
+    // each of these globals is treated as a device global.
+ for (Operation *globalOp : globalsToAccDeclare) {
+ LLVM_DEBUG(
+ llvm::dbgs() << "Global is being `acc declare copyin`d: ";
+ globalOp->print(llvm::dbgs(),
+ OpPrintingFlags{}.skipRegions().enableDebugInfo());
+ llvm::dbgs() << "\n");
+
+ // Mark it as declare copyin.
+ addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin);
+
+      // TODO: May need to create the global constructor which does the mapping
+      // action. It is not yet clear whether this is needed (since the globals
+      // might just end up in the GPU image without requiring mapping via the
+      // runtime).
+ }
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp
new file mode 100644
index 0000000..12efaf4
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp
@@ -0,0 +1,237 @@
+//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements the implicit rules described in the OpenACC
+// specification for the `Routine Directive` (OpenACC 3.4 spec, section 2.15.1).
+//
+// "If no explicit routine directive applies to a procedure whose definition
+// appears in the program unit being compiled, then the implementation applies
+// an implicit routine directive to that procedure if any of the following
+// conditions holds:
+// - The procedure is called or its address is accessed in a compute region."
+//
+// The specification further states:
+// "When the implementation applies an implicit routine directive to a
+// procedure, it must recursively apply implicit routine directives to other
+// procedures for which the above rules specify relevant dependencies. Such
+// dependencies can form a cycle, so the implementation must take care to avoid
+// infinite recursion."
+//
+// This pass implements these requirements by:
+// 1. Walking through all OpenACC compute constructs and functions already
+// marked with `acc routine` in the module and identifying function calls
+// within these regions.
+// 2. Creating implicit `acc.routine` operations for functions that don't
+// already have routine declarations.
+// 3. Recursively walking through all existing `acc routine` operations and
+//    creating implicit routine operations for function calls within these
+//    routines, while avoiding infinite recursion through proper tracking.
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements must be met:
+//
+// 1. Operation Interface Implementation: Operations that define functions
+// or call functions should implement `mlir::FunctionOpInterface` and
+// `mlir::CallOpInterface` respectively.
+//
+// 2. Analysis Registration (Optional): If custom behavior is needed for
+// determining if a symbol use is valid within GPU regions, the dialect
+// should pre-register the `acc::OpenACCSupport` analysis.
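+//
+// Example (an illustrative sketch; the exact printed form of the routine op
+// and attributes may differ):
+//
+// Before:
+//   func.func @device_fn() { ... }
+//   func.func @host_fn() {
+//     acc.parallel {
+//       func.call @device_fn() : () -> ()
+//       acc.yield
+//     }
+//   }
+//
+// After (an implicit acc.routine is created and the callee is annotated):
+//   acc.routine @acc_routine_0 func(@device_fn) implicit
+//   func.func @device_fn() attributes {acc.routine_info = ...} { ... }
+//   func.func @host_fn() {
+//     acc.parallel {
+//       func.call @device_fn() : () -> ()
+//       acc.yield
+//     }
+//   }
+//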
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include <queue>
+
+#define DEBUG_TYPE "acc-implicit-routine"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIMPLICITROUTINE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+namespace {
+
+using namespace mlir;
+
+class ACCImplicitRoutine
+ : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> {
+private:
+ unsigned routineCounter = 0;
+ static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_";
+
+ // Count existing routine operations and update counter
+ void initRoutineCounter(ModuleOp module) {
+ module.walk([&](acc::RoutineOp routineOp) { routineCounter++; });
+ }
+
+  // Returns true if the `acc routine` op has a default bind clause or a
+  // device-type specific bind clause.
+ bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op,
+ acc::DeviceType deviceType) {
+ // Fast check to avoid device-type specific lookups.
+ if (!op.getBindIdName() && !op.getBindStrName())
+ return false;
+ return op.getBindNameValue().has_value() ||
+ op.getBindNameValue(deviceType).has_value();
+ }
+
+ // Generate a unique name for the routine and create the routine operation
+ acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc,
+ FunctionOpInterface &callee) {
+ std::string routineName =
+ (accRoutinePrefix + std::to_string(routineCounter++)).str();
+ auto routineOp = acc::RoutineOp::create(
+ builder, loc,
+ /* sym_name=*/builder.getStringAttr(routineName),
+ /* func_name=*/
+ mlir::SymbolRefAttr::get(builder.getContext(),
+ builder.getStringAttr(callee.getName())),
+ /* bindIdName=*/nullptr,
+ /* bindStrName=*/nullptr,
+ /* bindIdNameDeviceType=*/nullptr,
+ /* bindStrNameDeviceType=*/nullptr,
+ /* worker=*/nullptr,
+ /* vector=*/nullptr,
+ /* seq=*/nullptr,
+ /* nohost=*/nullptr,
+ /* implicit=*/builder.getUnitAttr(),
+ /* gang=*/nullptr,
+ /* gangDim=*/nullptr,
+ /* gangDimDeviceType=*/nullptr);
+
+ // Assert that the callee does not already have routine info attribute
+ assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) &&
+ "function is already associated with a routine");
+
+ callee->setAttr(
+ acc::getRoutineInfoAttrName(),
+ mlir::acc::RoutineInfoAttr::get(
+ builder.getContext(),
+ {mlir::SymbolRefAttr::get(builder.getContext(),
+ builder.getStringAttr(routineName))}));
+ return routineOp;
+ }
+
+ // Used to walk through a compute region looking for function calls.
+ void
+ implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab,
+ mlir::OpBuilder &builder,
+ acc::OpenACCSupport &accSupport) {
+ op->walk([&](CallOpInterface callOp) {
+ if (!callOp.getCallableForCallee())
+ return;
+
+ auto calleeSymbolRef =
+ dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
+      // When the call is made through an SSA value, the callee is not a
+      // symbol. Skip it because we don't know the call target.
+ if (!calleeSymbolRef)
+ return;
+
+ auto callee = symTab.lookup<FunctionOpInterface>(
+ calleeSymbolRef.getLeafReference().str());
+      // The callee must be resolvable; if it is already a valid symbol for
+      // GPU regions, skip it.
+ assert(callee && "callee function must be found in symbol table");
+ if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
+ return;
+ builder.setInsertionPoint(callee);
+ createRoutineOp(builder, callee.getLoc(), callee);
+ });
+ }
+
+ // Recursively handle calls within a routine operation
+ void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp,
+ mlir::OpBuilder &builder,
+ acc::OpenACCSupport &accSupport,
+ acc::DeviceType targetDeviceType) {
+    // When a bind clause is used, the target is a different function than the
+    // one the `acc routine` is attached to. Skip this case to avoid
+    // recursively marking calls that would not end up on the device.
+ if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType))
+ return;
+
+ SymbolTable symTab(routineOp->getParentOfType<ModuleOp>());
+ std::queue<acc::RoutineOp> routineQueue;
+ routineQueue.push(routineOp);
+ while (!routineQueue.empty()) {
+ auto currentRoutine = routineQueue.front();
+ routineQueue.pop();
+ auto func = symTab.lookup<FunctionOpInterface>(
+ currentRoutine.getFuncName().getLeafReference());
+ func.walk([&](CallOpInterface callOp) {
+ if (!callOp.getCallableForCallee())
+ return;
+
+ auto calleeSymbolRef =
+ dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee());
+        // When the call is made through an SSA value, the callee is not a
+        // symbol. Skip it because we don't know the call target.
+ if (!calleeSymbolRef)
+ return;
+
+ auto callee = symTab.lookup<FunctionOpInterface>(
+ calleeSymbolRef.getLeafReference().str());
+        // The callee must be resolvable; if it is already a valid symbol for
+        // GPU regions, skip it.
+ assert(callee && "callee function must be found in symbol table");
+ if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef))
+ return;
+ builder.setInsertionPoint(callee);
+ auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee);
+ routineQueue.push(newRoutineOp);
+ });
+ }
+ }
+
+public:
+ using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase;
+
+ void runOnOperation() override {
+ auto module = getOperation();
+ mlir::OpBuilder builder(module.getContext());
+ SymbolTable symTab(module);
+ initRoutineCounter(module);
+
+ acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>();
+
+ // Handle compute regions
+ module.walk([&](Operation *op) {
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
+ implicitRoutineForCallsInComputeRegions(op, symTab, builder,
+ accSupport);
+ });
+
+ // Use the device type option from the pass options.
+ acc::DeviceType targetDeviceType = deviceType;
+
+ // Handle existing routines
+ module.walk([&](acc::RoutineOp routineOp) {
+ implicitRoutineForCallsInRoutine(routineOp, builder, accSupport,
+ targetDeviceType);
+ });
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
new file mode 100644
index 0000000..f41ce276
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
@@ -0,0 +1,117 @@
+//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts acc.serial into acc.parallel with num_gangs(1)
+// num_workers(1) vector_length(1).
+//
+// This transformation simplifies processing of acc regions by unifying the
+// handling of serial and parallel constructs. Since an OpenACC serial region
+// executes sequentially (like a parallel region with a single gang, worker, and
+// vector), this conversion is semantically equivalent while enabling code reuse
+// in later compilation stages.
+//
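+// For example (illustrative; the exact printed form of the operands may
+// differ):
+//
+//   acc.serial {
+//     ...
+//     acc.yield
+//   }
+//
+// becomes:
+//
+//   %c1_i32 = arith.constant 1 : i32
+//   acc.parallel num_gangs({%c1_i32 : i32}) num_workers(%c1_i32 : i32)
+//       vector_length(%c1_i32 : i32) {
+//     ...
+//     acc.yield
+//   }
+//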
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCLEGALIZESERIAL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-legalize-serial"
+
+namespace {
+using namespace mlir;
+
+struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
+ using OpRewritePattern<acc::SerialOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(acc::SerialOp serialOp,
+ PatternRewriter &rewriter) const override {
+
+ const Location loc = serialOp.getLoc();
+
+    // Create a container holding the constant value 1 for use as the
+    // num_gangs, num_workers, and vector_length operands.
+ llvm::SmallVector<mlir::Value> numValues;
+ auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
+ numValues.push_back(value);
+
+    // num_gangs takes a variadic list of values grouped per device_type, so
+    // create a segments attribute recording how many values belong to each
+    // group.
+ llvm::SmallVector<int32_t> numGangsSegments;
+ numGangsSegments.push_back(numValues.size());
+ auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);
+
+    // Create a device_type attribute set to `none`, which ensures that the
+    // parallelism dimensions apply as the default, regardless of device type.
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ rewriter.getContext(), mlir::acc::DeviceType::None);
+ crtDeviceTypes.push_back(crtDeviceTypeAttr);
+ auto devTypeAttr =
+ mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);
+
+ LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");
+
+    // Create a new acc.parallel op with the same operands, plus the
+    // num_gangs, num_workers, and vector_length operands set to 1.
+ acc::ParallelOp parOp = acc::ParallelOp::create(
+ rewriter, loc, serialOp.getAsyncOperands(),
+ serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
+ serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
+ serialOp.getWaitOperandsDeviceTypeAttr(),
+ serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
+ gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
+ devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
+ serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
+ serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
+ serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
+ serialOp.getCombinedAttr());
+
+ parOp.getRegion().takeBody(serialOp.getRegion());
+
+ LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
+ rewriter.replaceOp(serialOp, parOp);
+
+ return success();
+ }
+};
+
+class ACCLegalizeSerial
+ : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
+public:
+ using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns(context);
+ patterns.insert<ACCSerialOpConversion>(context);
+ (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index f8fff59..10a1796 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,8 @@
add_mlir_dialect_library(MLIROpenACCTransforms
ACCImplicitData.cpp
+ ACCImplicitDeclare.cpp
+ ACCImplicitRoutine.cpp
+ ACCLegalizeSerial.cpp
LegalizeDataValues.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index fbac28e..7f27b44 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,8 +9,13 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Casting.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -155,3 +160,109 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) {
return val;
}
+
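+/// Returns true if the use of `symbol` by `user` is already valid for device
+/// execution: the symbol resolves to a recipe, an `acc routine` function, a
+/// recognized LLVM intrinsic declaration, or an op that carries the declare
+/// attribute. When a defining op is found and `definingOpPtr` is non-null, it
+/// is set to that op.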
+bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
+ mlir::SymbolRefAttr symbol,
+ mlir::Operation **definingOpPtr) {
+ mlir::Operation *definingOp =
+ mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol);
+
+  // If there is no defining op, we have no way to ensure validity because
+  // we cannot check for any attributes.
+ if (!definingOp)
+ return false;
+
+ if (definingOpPtr)
+ *definingOpPtr = definingOp;
+
+ // Check if the defining op is a recipe (private, reduction, firstprivate).
+ // Recipes are valid as they get materialized before being offloaded to
+ // device. They are only instructions for how to materialize.
+ if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp,
+ mlir::acc::FirstprivateRecipeOp>(definingOp))
+ return true;
+
+ // Check if the defining op is a function
+ if (auto func =
+ mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) {
+ // If this symbol is actually an acc routine - then it is expected for it
+ // to be offloaded - therefore it is valid.
+ if (func->hasAttr(mlir::acc::getRoutineInfoAttrName()))
+ return true;
+
+  // If this symbol refers to an LLVM intrinsic, then it is likely valid.
+ // Check the following:
+ // 1. The function is private
+ // 2. The function has no body
+ // 3. Name starts with "llvm."
+ // 4. The function's name is a valid LLVM intrinsic name
+ if (func.getVisibility() == mlir::SymbolTable::Visibility::Private &&
+ func.getFunctionBody().empty() && func.getName().starts_with("llvm.") &&
+ llvm::Intrinsic::lookupIntrinsicID(func.getName()) !=
+ llvm::Intrinsic::not_intrinsic)
+ return true;
+ }
+
+  // Otherwise, the defining op must carry a declare attribute for the symbol
+  // reference to be valid.
+  return definingOp->hasAttr(mlir::acc::getDeclareAttrName());
+}
+
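+/// Collects the data clause operands that are guaranteed to be active around
+/// `computeConstructOp`: its own data clause operands, those of enclosing
+/// acc.data constructs, and those of acc.declare_enter ops that dominate the
+/// construct and whose acc.declare_exit users all post-dominate it.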
+llvm::SmallVector<mlir::Value>
+mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp,
+ mlir::DominanceInfo &domInfo,
+ mlir::PostDominanceInfo &postDomInfo) {
+ llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses;
+
+ llvm::TypeSwitch<mlir::Operation *>(computeConstructOp)
+ .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>(
+ [&](auto op) {
+ for (auto dataClause : op.getDataClauseOperands()) {
+ dominatingDataClauses.insert(dataClause);
+ }
+ })
+ .Default([](mlir::Operation *) {});
+
+ // Collect the data clauses from enclosing data constructs.
+ mlir::Operation *currParentOp = computeConstructOp->getParentOp();
+ while (currParentOp) {
+    if (auto dataOp = mlir::dyn_cast<mlir::acc::DataOp>(currParentOp)) {
+      for (auto dataClause : dataOp.getDataClauseOperands())
+        dominatingDataClauses.insert(dataClause);
+    }
+ currParentOp = currParentOp->getParentOp();
+ }
+
+ // Find the enclosing function/subroutine
+ auto funcOp =
+ computeConstructOp->getParentOfType<mlir::FunctionOpInterface>();
+ if (!funcOp)
+ return dominatingDataClauses.takeVector();
+
+ // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that
+ // dominate and post-dominate the compute construct and add their data
+ // clauses to the list.
+ funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) {
+ if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) {
+ // Collect all `acc.declare_exit` ops for this token.
+ llvm::SmallVector<mlir::acc::DeclareExitOp> exits;
+ for (auto *user : declareEnterOp.getToken().getUsers())
+ if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user))
+ exits.push_back(declareExit);
+
+ // Only add clauses if every `acc.declare_exit` op post-dominates the
+ // compute construct.
+ if (!exits.empty() &&
+ llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) {
+ return postDomInfo.postDominates(exitOp, computeConstructOp);
+ })) {
+ for (auto dataClause : declareEnterOp.getDataClauseOperands())
+ dominatingDataClauses.insert(dataClause);
+ }
+ }
+ });
+
+ return dominatingDataClauses.takeVector();
+}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1b069c6..103295d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -617,6 +617,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
break;
case ClauseScheduleKind::Auto:
case ClauseScheduleKind::Runtime:
+ case ClauseScheduleKind::Distribute:
chunkSize = std::nullopt;
}
@@ -1817,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
if (mapTypeMod == "ref_ptr_ptee")
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
+ if (mapTypeMod == "is_device_ptr")
+ mapTypeBits |= ClauseMapFlags::is_device_ptr;
+
return success();
};
@@ -1886,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
mapTypeStrs.push_back("ref_ptee");
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
mapTypeStrs.push_back("ref_ptr_ptee");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
+ mapTypeStrs.push_back("is_device_ptr");
if (mapFlags == ClauseMapFlags::none)
mapTypeStrs.push_back("none");
@@ -2824,6 +2830,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
+        /*linear_var_types=*/nullptr,
/*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
/*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/false,
@@ -2842,8 +2849,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
WsloopOp::build(
builder, state,
/*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
- clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
- clauses.ordered, clauses.privateVars,
+ clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait,
+ clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.reductionMod, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
@@ -2888,17 +2895,16 @@ LogicalResult WsloopOp::verifyRegions() {
void SimdOp::build(OpBuilder &builder, OperationState &state,
const SimdOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- // TODO Store clauses in op: linearVars, linearStepVars
- SimdOp::build(builder, state, clauses.alignedVars,
- makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
- /*linear_vars=*/{}, /*linear_step_vars=*/{},
- clauses.nontemporalVars, clauses.order, clauses.orderMod,
- clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.reductionMod,
- clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
- clauses.simdlen);
+ SimdOp::build(
+ builder, state, clauses.alignedVars,
+ makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
+ clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes,
+ clauses.nontemporalVars, clauses.order, clauses.orderMod,
+ clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
+ clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
+ clauses.simdlen);
}
LogicalResult SimdOp::verify() {
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 423e1c3..b111117 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
MLIRValueBoundsOpInterface
+ MLIRTransformUtils
)
-
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 881e256..c4bd31f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() {
}
namespace {
+/// Move a scf.if op that is directly before the scf.condition op in the while
+/// before region, and whose condition matches the condition of the
+/// scf.condition op, down into the while after region.
+///
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// %res = scf.if %cond -> (...) {
+/// use(%additional_used_values)
+/// ... // then block
+/// scf.yield %then_value
+/// } else {
+/// scf.yield %else_value
+/// }
+/// scf.condition(%cond) %res, ...
+/// } do {
+/// ^bb0(%res_arg, ...):
+/// use(%res_arg)
+/// ...
+///
+/// becomes
+/// scf.while (..) : (...) -> ... {
+/// %additional_used_values = ...
+/// %cond = ...
+/// ...
+/// scf.condition(%cond) %else_value, ..., %additional_used_values
+/// } do {
+///   ^bb0(%res_arg, ..., %additional_args):
+/// use(%additional_args)
+/// ... // if then block
+/// use(%then_value)
+/// ...
+struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
+ using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::WhileOp op,
+ PatternRewriter &rewriter) const override {
+ auto conditionOp = op.getConditionOp();
+
+    // Only support an ifOp right before the condition at the moment. Relaxing
+    // this would require:
+ // - check that the body does not have side-effects conflicting with
+ // operations between the if and the condition.
+ // - check that results of the if operation are only used as arguments to
+ // the condition.
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
+
+ // Check that the ifOp is directly before the conditionOp and that it
+ // matches the condition of the conditionOp. Also ensure that the ifOp has
+ // no else block with content, as that would complicate the transformation.
+ // TODO: support else blocks with content.
+ if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
+ (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
+ return failure();
+
+ assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
+ *ifOp->user_begin() == conditionOp)) &&
+ "ifOp has unexpected uses");
+
+ Location loc = op.getLoc();
+
+ // Replace uses of ifOp results in the conditionOp with the yielded values
+ // from the ifOp branches.
+ for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
+ auto it = llvm::find(ifOp->getResults(), arg);
+ if (it != ifOp->getResults().end()) {
+ size_t ifOpIdx = it.getIndex();
+ Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
+ Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+
+ rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
+ rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
+ }
+ }
+
+ // Collect additional used values from before region.
+ SetVector<Value> additionalUsedValuesSet;
+ visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
+ if (&op.getBefore() == operand->get().getParentRegion())
+ additionalUsedValuesSet.insert(operand->get());
+ });
+
+ // Create new whileOp with additional used values as results.
+ auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
+ auto additionalValueTypes = llvm::map_to_vector(
+ additionalUsedValues, [](Value val) { return val.getType(); });
+ size_t additionalValueSize = additionalUsedValues.size();
+ SmallVector<Type> newResultTypes(op.getResultTypes());
+ newResultTypes.append(additionalValueTypes);
+
+ auto newWhileOp =
+ scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
+
+ rewriter.modifyOpInPlace(newWhileOp, [&] {
+ newWhileOp.getBefore().takeBody(op.getBefore());
+ newWhileOp.getAfter().takeBody(op.getAfter());
+ newWhileOp.getAfter().addArguments(
+ additionalValueTypes,
+ SmallVector<Location>(additionalValueSize, loc));
+ });
+
+ rewriter.modifyOpInPlace(conditionOp, [&] {
+ conditionOp.getArgsMutable().append(additionalUsedValues);
+ });
+
+ // Replace uses of additional used values inside the ifOp then region with
+ // the whileOp after region arguments.
+ rewriter.replaceUsesWithIf(
+ additionalUsedValues,
+ newWhileOp.getAfterArguments().take_back(additionalValueSize),
+ [&](OpOperand &use) {
+ return ifOp.getThenRegion().isAncestor(
+ use.getOwner()->getParentRegion());
+ });
+
+ // Inline ifOp then region into new whileOp after region.
+ rewriter.eraseOp(ifOp.thenYield());
+ rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
+ newWhileOp.getAfterBody()->begin());
+ rewriter.eraseOp(ifOp);
+ rewriter.replaceOp(op,
+ newWhileOp->getResults().drop_back(additionalValueSize));
+ return success();
+ }
+};
+
/// Replace uses of the condition within the do block with true, since otherwise
/// the block would not be evaluated.
///
@@ -4343,7 +4471,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp loop,
PatternRewriter &rewriter) const override {
- auto oldBefore = loop.getBeforeBody();
+ auto *oldBefore = loop.getBeforeBody();
ConditionOp oldTerm = loop.getConditionOp();
ValueRange beforeArgs = oldBefore->getArguments();
ValueRange termArgs = oldTerm.getArgs();
@@ -4364,7 +4492,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
beforeArgs);
}
- auto oldAfter = loop.getAfterBody();
+ auto *oldAfter = loop.getAfterBody();
SmallVector<Type> newResultTypes(beforeArgs.size());
for (auto &&[i, j] : llvm::enumerate(*mapping))
@@ -4373,8 +4501,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
auto newLoop = WhileOp::create(
rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
/*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
- auto newBefore = newLoop.getBeforeBody();
- auto newAfter = newLoop.getAfterBody();
+ auto *newBefore = newLoop.getBeforeBody();
+ auto *newAfter = newLoop.getAfterBody();
SmallVector<Value> newResults(beforeArgs.size());
SmallVector<Value> newAfterArgs(beforeArgs.size());
@@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770f..009c2c3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
// Again assume that all the outer loops are scf.for operations.
- auto outerForLoop = cast<scf::ForOp>(outerLoop);
+ auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
auto outerLoopYield =
cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
return clonedSlices;
}
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
- RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
- MutableArrayRef<LoopLikeOpInterface> loops) {
- if (candidateSlices.empty()) {
- return rewriter.notifyMatchFailure(
- rewriter.getUnknownLoc(),
- "no candidate slices provided for consumer fusion");
- }
- // Return if `loops` is empty, return an error for now. Caller is expected
- // to handle this case.
- if (loops.empty()) {
- return rewriter.notifyMatchFailure(
- candidateSlices.front(),
- "cannot call tile and fuse consumer with an empty loop nest");
- }
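+/// Common implementation of consumer fusion, given the consumer op, the
+/// consumer operands that are fed by the loop nest, and the candidate
+/// insert-slice ops.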
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+ ArrayRef<OpOperand *> consumerOpOperands,
+ ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "expected loops to be not empty");
- if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
- llvm::all_of(candidateSlices,
- llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+ // 1. Check assumption for loop with `reorderOperations` disabled.
+ if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
return rewriter.notifyMatchFailure(
- candidateSlices.front(),
- "candidates slices need to be all `tensor.extract_slice`s or "
- "`tensor.parallel_insert_slice`s");
- }
-
- // 1. Get the consumer of scf.for for the result yielded by
- // tensor.insert_slice/parallel_insert_slice.
- SmallVector<OpOperand *> consumerOpOperands;
- Operation *consumerOp;
- {
- FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
- getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
- if (failed(maybeConsumerOpOperand)) {
- return rewriter.notifyMatchFailure(candidateSlices.front(),
- "could not fetch consumer to fuse");
- }
- std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
- consumerOp = consumerOpOperands.front()->getOwner();
+ loops.front(), "the first user of loop should not dominate any define "
+ "of consumer operand(s)");
}
LoopLikeOpInterface outerMostLoop = loops.front();
LoopLikeOpInterface innerMostLoop = loops.back();
- // Check assumption for loop with `reorderOperations` disabled.
- if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
- return rewriter.notifyMatchFailure(
- outerMostLoop, "the first user of loop should not dominate any define "
- "of consumer operand(s)");
- }
-
OpBuilder::InsertionGuard g(rewriter);
-
// 2. Check consumer is not using scf loop's output as init.
auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
if (!dstOp)
@@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices(
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
});
+ auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
return scf::SCFFuseConsumerOfSliceResult{
- std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+ std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
std::move(tileAndFuseResult->tiledOps)};
}
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+ RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (candidateSlices.empty()) {
+ return rewriter.notifyMatchFailure(
+ rewriter.getUnknownLoc(),
+ "no candidate slices provided for consumer fusion");
+ }
+  // If `loops` is empty, return an error for now. The caller is expected to
+  // handle this case.
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(
+ candidateSlices.front(),
+ "cannot call tile and fuse consumer with an empty loop nest");
+ }
+
+ if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+ llvm::all_of(candidateSlices,
+ llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+ return rewriter.notifyMatchFailure(
+ candidateSlices.front(),
+ "candidates slices need to be all `tensor.extract_slice`s or "
+ "`tensor.parallel_insert_slice`s");
+ }
+
+  // Get the consumer of the scf.for result that is yielded by the
+  // tensor.insert_slice/parallel_insert_slice.
+ FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
+ getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+ if (failed(maybeConsumerOpOperands)) {
+ return rewriter.notifyMatchFailure(candidateSlices.front(),
+ "could not fetch consumer to fuse");
+ }
+ Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
+
+ return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
+ maybeConsumerOpOperands.value(),
+ candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+ if (result.getOwner() != forallOp)
+ return std::nullopt;
+ BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+ SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+ // If the number of combining ops is not 1, then this is unexpected. Return
+ // nullopt.
+ if (combiningOps.size() != 1)
+ return std::nullopt;
+ return combiningOps[0];
+}
+
+/// For a given result of a tiled loop nest, return the insert slice-like op
+/// that is used for consumer fusion.
+static std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+ ArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "Expected loops to be not empty");
+ LoopLikeOpInterface outerMostLoop = loops.front();
+ if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
+ assert(loops.size() == 1 &&
+ "expected only a single loop when tiling using scf.forall");
+ return getProducingParallelInsertSlice(forallOp, result);
+ }
+ // Assume that the loop nest is a nested `scf.for` that is created through
+ // tiling and retrieve the `tensor.insert_slice` operation used to construct
+ // the result.
+ while (loops.size() != 1) {
+ LoopLikeOpInterface loop = loops.front();
+ if (result.getOwner() != loop)
+ return std::nullopt;
+ auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+ if (!forOp)
+ return std::nullopt;
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ auto innerForResult =
+ dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+ if (!innerForResult)
+ return std::nullopt;
+ result = innerForResult;
+ loops = loops.drop_front();
+ }
+ LoopLikeOpInterface loop = loops.front();
+ if (result.getOwner() != loop)
+ return std::nullopt;
+ auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+ if (!forOp)
+ return std::nullopt;
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+ .getDefiningOp<tensor::InsertSliceOp>();
+ if (!insertSliceOp)
+ return std::nullopt;
+ return insertSliceOp;
+}
+
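+/// Fuses `consumer` into the tiled loop nest `loops` by locating the
+/// insert-slice-like ops that produce the loop results it uses and then
+/// reusing the slice-based consumer fusion implementation.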
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (!isa<TilingInterface>(consumer)) {
+ return rewriter.notifyMatchFailure(
+ consumer, "unhandled consumer that does not implement TilingInterface");
+ }
+
+  // If `loops` is empty, return an error for now. The caller is expected to
+  // handle this case.
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(
+ consumer, "cannot call tile and fuse consumer with an empty loop nest");
+ }
+
+ LoopLikeOpInterface outermostLoop = loops.front();
+
+ // Collect the operands of the consumer that come from the outermost loop of
+ // the loop nest.
+ SmallVector<OpOperand *> consumerFusableOperands;
+ for (OpOperand &opOperand : consumer->getOpOperands()) {
+ if (opOperand.get().getDefiningOp() == outermostLoop) {
+ consumerFusableOperands.push_back(&opOperand);
+ }
+ }
+
+ // Nothing to fuse. Just return an empty set.
+ if (consumerFusableOperands.empty()) {
+ return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+ SmallVector<OpOperand *>{},
+ SmallVector<Operation *>{}};
+ }
+
+ // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+ // for fusion.
+ SmallVector<Operation *> candidateSlices;
+ candidateSlices.reserve(consumerFusableOperands.size());
+ for (OpOperand *opOperand : consumerFusableOperands) {
+ std::optional<Operation *> slice =
+ getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+ if (!slice) {
+ return rewriter.notifyMatchFailure(
+ consumer,
+ "couldnt find producing insert-slice like operation for operand");
+ }
+ candidateSlices.push_back(slice.value());
+ }
+ return tileAndFuseConsumerOfSlicesImpl(
+ rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
+}
+
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index f0b46e6..a846d7e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -220,6 +220,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
}
//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+ Block *defaultTarget, ValueRange defaultOperands,
+ DenseIntElementsAttr literals, BlockRange targets,
+ ArrayRef<ValueRange> targetOperands) {
+ build(builder, result, selector, defaultOperands, targetOperands, literals,
+ defaultTarget, targets);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+ Block *defaultTarget, ValueRange defaultOperands,
+ ArrayRef<APInt> literals, BlockRange targets,
+ ArrayRef<ValueRange> targetOperands) {
+ DenseIntElementsAttr literalsAttr;
+ if (!literals.empty()) {
+ ShapedType literalType = VectorType::get(
+ static_cast<int64_t>(literals.size()), selector.getType());
+ literalsAttr = DenseIntElementsAttr::get(literalType, literals);
+ }
+ build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
+ targets, targetOperands);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+ Block *defaultTarget, ValueRange defaultOperands,
+ ArrayRef<int32_t> literals, BlockRange targets,
+ ArrayRef<ValueRange> targetOperands) {
+ DenseIntElementsAttr literalsAttr;
+ if (!literals.empty()) {
+ ShapedType literalType = VectorType::get(
+ static_cast<int64_t>(literals.size()), selector.getType());
+ literalsAttr = DenseIntElementsAttr::get(literalType, literals);
+ }
+ build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
+ targets, targetOperands);
+}
+
+LogicalResult SwitchOp::verify() {
+  std::optional<DenseIntElementsAttr> literals = getLiterals();
+  BlockRange targets = getTargets();
+
+  // The literals and targets come in pairs; check the counts first so that a
+  // missing literals attribute with non-empty targets is diagnosed instead of
+  // being dereferenced below.
+  int64_t numLiterals = literals ? literals->size() : 0;
+  if (numLiterals != static_cast<int64_t>(targets.size()))
+    return emitOpError() << "number of literals (" << numLiterals
+                         << ") should match number of targets ("
+                         << targets.size() << ")";
+
+  if (!literals)
+    return success();
+
+  Type selectorType = getSelector().getType();
+  Type literalType = literals->getType().getElementType();
+  if (literalType != selectorType)
+    return emitOpError() << "'selector' type (" << selectorType
+                         << ") should match literals type (" << literalType
+                         << ")";
+  return success();
+}
+
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
+ assert(index < getNumSuccessors() && "invalid successor index");
+ return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+ : getTargetOperandsMutable(index - 1));
+}
+
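+/// Folds a constant selector to the matching target block, falling back to
+/// the default target when no literal matches; returns nullptr when the
+/// selector is not a known constant.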
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+ std::optional<DenseIntElementsAttr> literals = getLiterals();
+
+ if (!literals)
+ return getDefaultTarget();
+
+ SuccessorRange targets = getTargets();
+ if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) {
+ for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
+ if (literal == value.getValue())
+ return targets[index];
+ return getDefaultTarget();
+ }
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index 2f3a28f..8575487 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
}
}
+/// Adapted from the cf.switch implementation.
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
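+///
+/// For example (block names and operand types are illustrative only):
+///   default: ^bb1(%arg0 : i32),
+///   0: ^bb2,
+///   5: ^bb3(%val : f32)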
+static ParseResult parseSwitchOpCases(
+ OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
+ SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
+ SmallVectorImpl<Block *> &targets,
+ SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
+ &targetOperands,
+ SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
+ if (parser.parseKeyword("default") || parser.parseColon() ||
+ parser.parseSuccessor(defaultTarget))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
+ /*allowResultNumber=*/false) ||
+ parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
+ return failure();
+ }
+
+ SmallVector<APInt> values;
+ unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
+ while (succeeded(parser.parseOptionalComma())) {
+ int64_t value = 0;
+ if (failed(parser.parseInteger(value)))
+ return failure();
+ values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
+
+ Block *target;
+ SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ SmallVector<Type> operandTypes;
+ if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (failed(parser.parseOperandList(operands,
+ OpAsmParser::Delimiter::None)) ||
+ failed(parser.parseColonTypeList(operandTypes)) ||
+ failed(parser.parseRParen()))
+ return failure();
+ }
+ targets.push_back(target);
+ targetOperands.emplace_back(operands);
+ targetOperandTypes.emplace_back(operandTypes);
+ }
+
+ if (!values.empty()) {
+ ShapedType literalType =
+ VectorType::get(static_cast<int64_t>(values.size()), selectorType);
+ literals = DenseIntElementsAttr::get(literalType, values);
+ }
+ return success();
+}
+
+static void
+printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
+ Block *defaultTarget, OperandRange defaultOperands,
+ TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
+ SuccessorRange targets, OperandRangeRange targetOperands,
+ const TypeRangeRange &targetOperandTypes) {
+ p << " default: ";
+ p.printSuccessorAndUseList(defaultTarget, defaultOperands);
+
+ if (!literals)
+ return;
+
+ for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
+ p << ',';
+ p.printNewline();
+ p << " ";
+ p << literal.getLimitedValue();
+ p << ": ";
+ p.printSuccessorAndUseList(targets[index], targetOperands[index]);
+ }
+ p.printNewline();
+}
+
} // namespace mlir::spirv
// TableGen'erated operation definitions.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index cb9b7f6..f07307f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
<< type << " illegal: cannot handle zero-element tensors\n");
return nullptr;
}
+ if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: cannot fit tensor into target type\n");
+ return nullptr;
+ }
Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
if (!arrayElemType)
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 645cbff..5941f7d 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames(
//===----------------------------------------------------------------------===//
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr grid,
- ArrayRef<GridAxesAttr> split_axes,
- ArrayRef<int64_t> static_halos,
- ArrayRef<int64_t> static_offsets) {
+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
+ ArrayRef<int64_t> staticHalos,
+ ArrayRef<int64_t> staticOffsets) {
return build(
- b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
- ArrayRef<int64_t> static_halos,
- ArrayRef<int64_t> static_offsets) {
+ llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
+ ArrayRef<int64_t> staticHalos,
+ ArrayRef<int64_t> staticOffsets) {
return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
- GridAxesArrayAttr::get(b.getContext(), split_axes),
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
+ GridAxesArrayAttr::get(b.getContext(), splitAxes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
{});
}
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
+ ::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes,
+ ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
- dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
- dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
+ dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
+ dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
return build(
- b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
@@ -576,7 +575,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
- getStaticShardedDimsOffsets().size() > 0) {
+ !getStaticShardedDimsOffsets().empty()) {
return emitError() << "sharded dims offsets are not allowed for "
"device grids with dynamic shape.";
}
@@ -650,14 +649,14 @@ public:
if (dynamicOffs.empty() && !staticOffs.empty()) {
assert(staticOffs.size() >= 2);
auto diff = staticOffs[1] - staticOffs[0];
- bool all_same = staticOffs.size() > 2;
+ bool allSame = staticOffs.size() > 2;
for (auto i = 2u; i < staticOffs.size(); ++i) {
if (staticOffs[i] - staticOffs[i - 1] != diff) {
- all_same = false;
+ allSame = false;
break;
}
}
- if (all_same) {
+ if (allSame) {
staticOffs.clear();
modified = true;
}
@@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const {
bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
-Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
+Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {}
Sharding::Sharding(Value rhs) {
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
@@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) {
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
-Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
- ArrayRef<GridAxesAttr> split_axes_,
- ArrayRef<int64_t> static_halo_sizes_,
- ArrayRef<int64_t> static_sharded_dims_offsets_,
- ArrayRef<Value> dynamic_halo_sizes_,
- ArrayRef<Value> dynamic_sharded_dims_offsets_) {
- Sharding res(grid_);
- if (split_axes_.empty()) {
+Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid,
+ ArrayRef<GridAxesAttr> splitAxes,
+ ArrayRef<int64_t> staticHaloSizes,
+ ArrayRef<int64_t> staticShardedDimsOffsets,
+ ArrayRef<Value> dynamicHaloSizes,
+ ArrayRef<Value> dynamicShardedDimsOffsets) {
+ Sharding res(grid);
+ if (splitAxes.empty()) {
return res;
}
- res.split_axes.resize(split_axes_.size());
- for (auto [i, axis] : llvm::enumerate(split_axes_)) {
- res.split_axes[i] =
- GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
+ res.split_axes.resize(splitAxes.size());
+ for (auto [i, axis] : llvm::enumerate(splitAxes)) {
+ res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
}
auto clone = [](const auto src, auto &dst) {
@@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
llvm::copy(src, dst.begin());
};
- clone(static_halo_sizes_, res.static_halo_sizes);
- clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
- clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
- clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
+ clone(staticHaloSizes, res.static_halo_sizes);
+ clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
+ clone(dynamicHaloSizes, res.dynamic_halo_sizes);
+ clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
return res;
}
@@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames(
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::llvm::ArrayRef<int64_t> dims,
- ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
+ ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
::mlir::ValueRange device) {
SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
- build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
+ build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index 3bfbf373..f954131 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -184,7 +184,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
for (auto [result, sharding] :
llvm::zip_equal(op->getResults(), resultShardings)) {
- for (auto user : result.getUsers()) {
+ for (auto *user : result.getUsers()) {
ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
if (!shardOp) {
continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index ae7eef2..9db9814 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1365,8 +1365,8 @@ public:
arith::SubIOp::create(rewriter, loc, capacity, newSize);
Value fillValue = constantZero(rewriter, loc, value.getType());
Value subBuffer = memref::SubViewOp::create(
- rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize},
- /*size=*/ValueRange{fillSize},
+ rewriter, loc, newBuffer, /*offsets=*/ValueRange{newSize},
+ /*sizes=*/ValueRange{fillSize},
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
linalg::FillOp::create(rewriter, loc, fillValue, subBuffer);
}
@@ -1386,8 +1386,8 @@ public:
memref::StoreOp::create(rewriter, loc, value, buffer, size);
} else {
Value subBuffer = memref::SubViewOp::create(
- rewriter, loc, buffer, /*offset=*/ValueRange{size},
- /*size=*/ValueRange{n},
+ rewriter, loc, buffer, /*offsets=*/ValueRange{size},
+ /*sizes=*/ValueRange{n},
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
linalg::FillOp::create(rewriter, loc, value, subBuffer);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index febec6d..23436a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
SmallVector<Value> scalarArgs(idxs);
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
- vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask,
- rhs);
+ vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem,
+ scalarArgs, indexVec, vmask, rhs);
return;
}
vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index ffa8b40..9904803 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -80,6 +80,53 @@ inline static bool includesDenseOutput(SortMask mask) {
return includesAny(mask, SortMask::kIncludeDenseOutput);
}
+/// Returns a sparsity rank for loop ordering: lower values indicate
+/// dimensions that should be placed in outer loops.
+/// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown.
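+///
+/// For example, a loop that indexes an unannotated (dense) tensor or a dense
+/// level of a sparse tensor is ranked 0 and preferred for the outer loops,
+/// while a loop that only reaches singleton levels is ranked 2.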
+static unsigned getLoopSparsityRank(unsigned loop, ArrayRef<Value> allTensors,
+ ArrayRef<AffineMap> allMaps) {
+ // Start with highest rank.
+ unsigned minRank = 3;
+
+ for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) {
+ // Check if this loop accesses this tensor.
+ bool loopAccessesTensor = false;
+ unsigned tensorDim = 0;
+ for (AffineExpr expr : map.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (dimExpr.getPosition() == loop) {
+ loopAccessesTensor = true;
+ break;
+ }
+ }
+ tensorDim++;
+ }
+
+ if (loopAccessesTensor) {
+ const auto enc = getSparseTensorEncoding(tensor.getType());
+ if (!enc) {
+ // Dense tensor - lowest rank.
+ return 0;
+ } else {
+ // Sparse tensor - check the level type for this dimension.
+ auto lvlTypes = enc.getLvlTypes();
+ if (tensorDim < lvlTypes.size()) {
+ auto lvlType = lvlTypes[tensorDim];
+ if (isDenseLT(lvlType)) {
+ return 0; // Dense level.
+ } else if (isCompressedLT(lvlType)) {
+ minRank = std::min(minRank, 1u); // Compressed level.
+ } else if (isSingletonLT(lvlType)) {
+ minRank = std::min(minRank, 2u); // Singleton level.
+ }
+ }
+ }
+ }
+ }
+
+ return minRank;
+}
+
AffineMap IterationGraphSorter::topoSort() {
// The sorted result will put the first Reduction iterator to the
// latest possible position.
@@ -107,10 +154,33 @@ AffineMap IterationGraphSorter::topoSort() {
case sparse_tensor::LoopOrderingStrategy::kDefault:
src = it.back();
break;
+ case sparse_tensor::LoopOrderingStrategy::kDenseOuter: {
+ // Prefer dense, then compressed, then singleton dimensions outermost.
+ // Create combined tensor and map lists for analysis.
+ SmallVector<Value> allTensors = ins;
+ allTensors.push_back(out);
+ SmallVector<AffineMap> allMaps = loop2InsLvl;
+ allMaps.push_back(loop2OutLvl);
+
+ // Find loop with minimum (lowest) sparsity rank.
+ unsigned minLoop = it[0];
+ unsigned minRank = getLoopSparsityRank(minLoop, allTensors, allMaps);
+
+ for (auto candidateLoop : it) {
+ unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps);
+ if (rank < minRank || (rank == minRank && candidateLoop < minLoop)) {
+ minLoop = candidateLoop;
+ minRank = rank;
+ }
+ }
+ src = minLoop;
+ break;
+ }
}
loopOrder.push_back(src);
- it.pop_back();
+ // Remove the selected loop from the worklist.
+ it.erase(std::find(it.begin(), it.end(), src));
// Update in-degree, and push 0-degree node into worklist.
for (unsigned dst = 0; dst < numLoops; dst++) {
if (itGraph[src][dst] && --inDegree[dst] == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 3636f3f..46378b9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -197,7 +197,7 @@ public:
// Sets the iterate to the specified position.
void seek(ValueRange vals) {
assert(vals.size() == cursorValsCnt);
- std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
+ llvm::copy(vals, cursorValsStorageRef.begin());
// Now that the iterator is re-positioned, the coordinate becomes invalid.
crd = nullptr;
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e1..686f6ee 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
struct ReifyExpandShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
ExpandShapeOp> {
+ using Base =
+ ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp>;
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdc..204e9bb 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
auto tensorTypes =
- llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
- return llvm::cast<RankedTensorType>(type);
- }));
+ llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
int64_t concatRank = tensorTypes[0].getRank();
// The concatenation dim must be in the range [0, rank).
@@ -2293,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
@@ -2307,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}
// TODO: This uses neither offsets nor strides!
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
@@ -2329,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+ inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -2352,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
- staticStrides);
+ desiredResultRank, sourceRankedTensorType, staticSizes);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2380,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
+ resultType = llvm::cast<RankedTensorType>(
+ ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2451,13 +2445,26 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}
+/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
+/// result type, offsets set to 0 and strides set to 1.
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
+ RankedTensorType resultType, Value source,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<NamedAttribute> attrs) {
+ Attribute zeroIdxAttr = b.getIndexAttr(0);
+ Attribute oneIdxAttr = b.getIndexAttr(1);
+ SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
+ SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
+ build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
+}
+
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
- RankedTensorType expectedType = ExtractSliceOp::inferResultType(
- sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ RankedTensorType expectedType =
+ ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2697,8 +2704,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
- mixedStrides);
+ op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};
@@ -2839,8 +2845,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
- RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, staticOffsets, staticSizes, staticStrides);
+ RankedTensorType expected =
+ ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2968,7 +2974,7 @@ public:
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
- mixedOffsets, mixedSizes, mixedStrides);
+ mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
@@ -3896,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
+// to 0, strides set to 1 and inferred result type.
+void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
+ Value dest, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<NamedAttribute> attrs) {
+ Attribute zeroIdxAttr = b.getIndexAttr(0);
+ Attribute oneIdxAttr = b.getIndexAttr(1);
+ SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
+ SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
+ build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
+}
+
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected InParallelOpInterface parent, got:")
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c607ece..310e725 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1132,35 +1132,22 @@ struct ConcatOpInterface
// Extract the dimension for the concat op
uint64_t concatDim = concatOp.getDim();
- bool dynamicConcatDim = false;
SmallVector<OpFoldResult> offsets(tensorType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(tensorType.getRank(),
rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes;
-
- for (const auto &[dimIdx, dimSize] :
- llvm::enumerate(tensorType.getShape())) {
- if (dimSize == ShapedType::kDynamic) {
- auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
- sizes.push_back(dimOp.getResult());
- if (dimIdx == concatDim)
- dynamicConcatDim = true;
- } else {
- sizes.push_back(rewriter.getIndexAttr(dimSize));
- }
- }
-
- int64_t concatDimOffset = 0;
- std::optional<Value> dynamicOffset;
- std::optional<Value> dynamicSize;
- if (dynamicConcatDim) {
- // One or more operands have dynamic size, so we must accumulate the
- // offset with arith ops.
- dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- }
+ SmallVector<OpFoldResult> sizes =
+ memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1,
+ {v1, v2});
+ };
+ OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
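+    // The offsets along the concatenation dimension accumulate across
+    // operands, e.g. 0, size0, size0 + size1, ... for successive inputs.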
for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1158,10 @@ struct ConcatOpInterface
// so the offset on that axis must accumulate through the loop, and the
// size must change to the size of the current operand.
auto operandTensorType = cast<RankedTensorType>(operand.getType());
- int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
- if (dynamicConcatDim) {
- offsets[concatDim] = dynamicOffset.value();
- dynamicSize =
- memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
- .getResult();
- sizes[concatDim] = dynamicSize.value();
- } else {
- sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
- offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
- }
+ offsets[concatDim] = concatDimOffset;
+ OpFoldResult concatDimSize =
+ memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+ sizes[concatDim] = concatDimSize;
// Create a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1176,7 @@ struct ConcatOpInterface
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
return failure();
- if (dynamicConcatDim) {
- dynamicOffset = arith::AddIOp::create(
- rewriter, loc, dynamicOffset.value(), dynamicSize.value());
- } else {
- concatDimOffset += operandConcatDimSize;
- }
+ concatDimOffset = sum(concatDimOffset, concatDimSize);
}
replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 7ec61c7..a53af98 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
- srcType, extractSliceOp.getStaticOffsets(),
- extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+ srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();
@@ -533,8 +532,8 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
getMixedSizes(b, loc, sliceOp.getSource());
// Helper variables and function for accumulating the size values.
- AffineExpr d0, d1, d2;
- bindDims(b.getContext(), d0, d1, d2);
+ AffineExpr d0, d1;
+ bindDims(b.getContext(), d0, d1);
// Multiply two integers.
auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
auto mulMap = AffineMap::get(2, 0, {d0 * d1});
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index 753cb95..d35f458 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -155,13 +155,15 @@ struct ExtractSliceOpInterface
RankedTensorType sourceType = extractSliceOp.getSource().getType();
// For each dimension, assert that:
- // 0 <= offset < dim_size
- // 0 <= offset + (size - 1) * stride < dim_size
+    //  For empty slices (size == 0):
+    //    0 <= offset <= dim_size
+    //  For non-empty slices (size > 0):
+    //    0 <= offset < dim_size
+    //    0 <= offset + (size - 1) * stride < dim_size
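+    // For example, a zero-sized slice whose offset equals dim_size is
+    // accepted, whereas a one-element slice at that offset is rejected.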
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
- // Reset insertion point to before the operation for each dimension
+
builder.setInsertionPoint(extractSliceOp);
Value offset = getValueOrCreateConstantIndexOp(
@@ -170,46 +172,63 @@ struct ExtractSliceOpInterface
builder, loc, extractSliceOp.getMixedSizes()[i]);
Value stride = getValueOrCreateConstantIndexOp(
builder, loc, extractSliceOp.getMixedStrides()[i]);
-
- // Verify that offset is in-bounds.
Value dimSize = builder.createOrFold<tensor::DimOp>(
loc, extractSliceOp.getSource(), i);
- Value offsetInBounds =
- generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(builder, loc, offsetInBounds,
+
+ // Verify that offset is in-bounds (conditional on slice size).
+ Value sizeIsZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, size, zero);
+ auto offsetCheckIf = scf::IfOp::create(
+ builder, loc, sizeIsZero,
+ [&](OpBuilder &b, Location loc) {
+ // For empty slices, offset can be at the boundary: 0 <= offset <=
+ // dimSize.
+ Value offsetGEZero = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sge, offset, zero);
+ Value offsetLEDimSize = arith::CmpIOp::create(
+ b, loc, arith::CmpIPredicate::sle, offset, dimSize);
+ Value emptyOffsetValid =
+ arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
+ scf::YieldOp::create(b, loc, emptyOffsetValid);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // For non-empty slices, offset must be a valid index: 0 <= offset <
+ // dimSize.
+ Value offsetInBounds =
+ generateInBoundsCheck(b, loc, offset, zero, dimSize);
+ scf::YieldOp::create(b, loc, offsetInBounds);
+ });
+
+ Value offsetCondition = offsetCheckIf.getResult(0);
+ cf::AssertOp::create(builder, loc, offsetCondition,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));
- // Only verify if size > 0
+ // Verify that the slice endpoint is in-bounds (only for non-empty
+ // slices).
Value sizeIsNonZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::sgt, size, zero);
+ auto ifOp = scf::IfOp::create(
+ builder, loc, sizeIsNonZero,
+ [&](OpBuilder &b, Location loc) {
+ // Verify that slice does not run out-of-bounds.
+ Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
+ Value sizeMinusOneTimesStride =
+ arith::MulIOp::create(b, loc, sizeMinusOne, stride);
+ Value lastPos =
+ arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
+ Value lastPosInBounds =
+ generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(b, loc, lastPosInBounds);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value trueVal =
+ arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
+ scf::YieldOp::create(b, loc, trueVal);
+ });
- auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
- sizeIsNonZero, /*withElseRegion=*/true);
-
- // Populate the "then" region (for size > 0).
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
- Value sizeMinusOneTimesStride =
- arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
- Value lastPos =
- arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
- Value lastPosInBounds =
- generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
- scf::YieldOp::create(builder, loc, lastPosInBounds);
-
- // Populate the "else" region (for size == 0).
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- Value trueVal =
- arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
- scf::YieldOp::create(builder, loc, trueVal);
-
- builder.setInsertionPointAfter(ifOp);
Value finalCondition = ifOp.getResult(0);
-
cf::AssertOp::create(
builder, loc, finalCondition,
generateErrorMessage(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 293c6af..c420a4c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
}
Attribute newMinValAttr, newMaxValAttr;
@@ -1485,7 +1486,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
return {};
}
+static bool
+mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) {
+ const auto isDynamic = [](Type ty) {
+ const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
+ return !shapedTy || !shapedTy.hasStaticShape();
+ };
+
+ return llvm::any_of(operandTypes, isDynamic) ||
+ failed(verifyCompatibleShapes(operandTypes));
+}
+
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
+ // Select allows operand shapes to be broadcast to the output shape. For
+ // now, don't support folding when we cannot prove no broadcasting is
+ // involved.
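+  // For example, if %on_true == %on_false has type tensor<1xf32> but the
+  // result type is tensor<4xf32>, folding to the operand would change the
+  // result type.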
+ if (mayRequireBroadcast(getOperandTypes()))
+ return {};
+
if (getOnTrue() == getOnFalse())
return getOnTrue();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 65e0a59..1c175f9ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
- srcType = quantType.getStorageType();
+ srcType = getStorageElementTypeFromQuantized(quantType);
return srcType;
}
@@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) {
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
- weightEType = quantType.getStorageType();
+ weightEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
- biasEType = quantType.getStorageType();
+ biasEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
// for now, only enforce bias element type == result element type for
@@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() {
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
outputType.getElementType())) {
- if (result.getStorageType() == attrType.getElementType())
+ if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
return success();
}
@@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
@@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
return success();
}
@@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() {
llvm::cast<ShapedType>(getInput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
- inputETy = quantType.getStorageType();
+ inputETy = getStorageElementTypeFromQuantized(quantType);
}
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
- outputETy = quantType.getStorageType();
+ outputETy = getStorageElementTypeFromQuantized(quantType);
}
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 41b338d..091b481 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaAttachTarget.cpp
+ TosaArithConstantToConst.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
@@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
+ TosaNarrowI64ToI32.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
@@ -21,7 +23,9 @@ add_mlir_dialect_library(MLIRTosaTransforms
LINK_LIBS PUBLIC
MLIRFuncDialect
+ MLIRFuncTransformOps
MLIRPass
MLIRTosaDialect
MLIRTransformUtils
+ MLIRFuncTransforms
)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
new file mode 100644
index 0000000..73e1e2b
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp
@@ -0,0 +1,111 @@
+//===- TosaArithConstantToConst.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that converts tensor-valued arith.constant ops
+// into tosa.const so that TOSA pipelines operate on a uniform constant form.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// NOTE: TOSA pipelines already lower their constants through shared Arith
+// folding passes, so tensor literals often come back as `arith.constant` even
+// after the IR is otherwise TOSA-only. Keep this normalization with the rest of
+// the TOSA transforms so any client can re-establish a canonical `tosa.const`
+// representation without needing a full Arith->TOSA conversion library.
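+//
+// For example, `arith.constant dense<0> : tensor<4xi32>` is rewritten into an
+// equivalent `tosa.const` carrying the same dense elements attribute.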
+
+/// Returns true when `elementType` is natively representable by tosa.const.
+static bool isSupportedElementType(Type elementType) {
+ if (isa<FloatType>(elementType))
+ return true;
+
+ if (auto intType = dyn_cast<IntegerType>(elementType))
+ return intType.isSignless() || intType.isUnsigned();
+
+ if (isa<quant::QuantizedType>(elementType))
+ return true;
+
+ if (isa<tosa::mxint8Type>(elementType))
+ return true;
+
+ return false;
+}
+
+class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp constOp,
+ PatternRewriter &rewriter) const override {
+ // TOSA constant verification requires a ranked, statically shaped tensor.
+ auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType());
+ if (!resultType || !resultType.hasStaticShape())
+ return failure();
+
+ if (!isSupportedElementType(resultType.getElementType()))
+ return failure();
+
+ Attribute attr = constOp.getValueAttr();
+ auto elementsAttr = dyn_cast<ElementsAttr>(attr);
+ if (!elementsAttr)
+ return failure();
+
+ auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType());
+ if (!attrType || !attrType.hasStaticShape())
+ return failure();
+ if (attrType != resultType)
+ return failure();
+
+ auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
+ resultType, elementsAttr);
+ rewriter.replaceOp(constOp, newConst.getResult());
+ return success();
+ }
+};
+
+struct TosaArithConstantToTosaConstPass
+ : public tosa::impl::TosaArithConstantToTosaConstPassBase<
+ TosaArithConstantToTosaConstPass> {
+ using Base::Base;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, tosa::TosaDialect>();
+ }
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ patterns.add<ArithConstantToTosaConst>(ctx);
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 0bec0da..022476a2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
ShapedType weightType = cast<ShapedType>(weight.getType());
ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
- if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
- resultType.hasStaticShape())) {
+    // Only the batch dimension of the input/output may be dynamic.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightType.hasStaticShape()) {
return failure();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b..8b23fd1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ public:
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
return failure();
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // Only the batch dimension of the input/output may be dynamic.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ public:
if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(op, "non-one stride found.");
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // Only the batch dimension of the input/output may be dynamic.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t batch = inputTy.getDimSize(0);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
new file mode 100644
index 0000000..ddaf7d8a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
@@ -0,0 +1,310 @@
+//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass narrows TOSA operations with 64-bit integer tensor types to
+// 32-bit integer tensor types. This can be useful for backends that do not
+// support the EXT-INT64 extension of TOSA. The pass has two options:
+//
+// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
+//   regardless of whether the narrowing is safe. This option may lead to
+// data loss if not used carefully.
+// - convert-function-boundaries - If enabled, the pass will convert function
+// I/O types as well. Otherwise casts will be inserted at the I/O
+// boundaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
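+/// Recreates `op` with its result types, attributes, and region types run
+/// through `typeConverter`, reusing the already-converted `operands`, and
+/// replaces the original operation with the new one.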
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter *typeConverter) {
+ // Convert types of results
+ SmallVector<Type, 4> newResults;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+ return failure();
+
+ // Create a new operation state
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, {}, op->getSuccessors());
+
+ for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+ const Attribute attribute = namedAttribute.getValue();
+
+ // Convert integer attribute type
+ if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
+      if (!convertedAttribute)
+        return rewriter.notifyMatchFailure(
+            op, "Failed to convert integer attribute.");
+      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+ Type type = typeAttr.getValue();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, attribute);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(op,
+ "Failed to convert type attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+ const Type type = denseElementsAttr.getType();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, denseElementsAttr);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(
+ op, "Failed to convert dense elements attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ state.addAttribute(namedAttribute.getName(), attribute);
+ }
+
+ for (Region &region : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+ return failure();
+ }
+
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+}
+
+// ===========================
+// Aggressive rewrite patterns
+// ===========================
+
+class ConvertGenericOp : public ConversionPattern {
+public:
+ ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!isa<tosa::TosaOp>(op))
+ return rewriter.notifyMatchFailure(
+ op,
+ "Support for operations other than TOSA has not been implemented.");
+
+ return convertGenericOp(op, operands, rewriter, typeConverter);
+ }
+};
+
+// ===============================
+// Bounds checked rewrite patterns
+// ===============================
+
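+// For example, a tosa.argmax over an axis of static size 128 can only produce
+// indices in [0, 127], so its i64 result can be narrowed to i32 without data
+// loss; dynamic axis dimensions are rejected.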
+class ConvertArgMaxOpWithBoundsChecking
+ : public OpConversionPattern<tosa::ArgMaxOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // Output type can be narrowed based on the size of the axis dimension
+ const int32_t axis = op.getAxis();
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ if (!inputType || !inputType.isStaticDim(axis))
+ return rewriter.notifyMatchFailure(
+ op, "Requires a static axis dimension for bounds checking.");
+ const int64_t axisDim = inputType.getDimSize(axis);
+ if (axisDim >= std::numeric_limits<int32_t>::max())
+ return rewriter.notifyMatchFailure(
+ op, "Axis dimension is too large to narrow safely.");
+
+ const Type resultType = op.getOutput().getType();
+ const Type newResultType = typeConverter->convertType(resultType);
+ rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
+ adaptor.getInput(), axis);
+ return success();
+ }
+};
+
+class ConvertCastOpWithBoundsChecking
+ : public OpConversionPattern<tosa::CastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+ if (!inputType || !resultType)
+ return failure();
+
+ const auto elementInputIntType =
+ dyn_cast<IntegerType>(inputType.getElementType());
+ const auto elementResultIntType =
+ dyn_cast<IntegerType>(resultType.getElementType());
+ if (elementInputIntType && elementResultIntType &&
+ elementInputIntType.getWidth() > elementResultIntType.getWidth())
+ return rewriter.notifyMatchFailure(
+ op, "Narrowing cast may lead to data loss.");
+
+ rewriter.replaceOpWithNewOp<tosa::CastOp>(
+ op, typeConverter->convertType(resultType), adaptor.getInput());
+ return success();
+ }
+};
+
+template <typename OpTy>
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ return convertGenericOp(op, adaptor.getOperands(), rewriter,
+ this->getTypeConverter());
+ }
+};
+
+struct TosaNarrowI64ToI32
+ : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
+public:
+ explicit TosaNarrowI64ToI32() = default;
+ explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
+ : TosaNarrowI64ToI32() {
+ this->aggressiveRewrite = options.aggressiveRewrite;
+ this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) -> Type { return type; });
+ typeConverter.addConversion([](IntegerType type) -> Type {
+ if (!type.isInteger(64))
+ return type;
+ return IntegerType::get(type.getContext(), 32);
+ });
+ typeConverter.addConversion(
+ [&typeConverter](RankedTensorType type) -> Type {
+ const Type elementType = type.getElementType();
+ if (!elementType.isInteger(64))
+ return type;
+ return RankedTensorType::get(type.getShape(),
+ typeConverter.convertType(elementType));
+ });
+
+ const auto materializeCast = [](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) -> Value {
+ if (inputs.size() != 1)
+ return Value();
+ return tosa::CastOp::create(builder, loc, resultType, inputs.front());
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
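+    // Integer attributes are narrowed with signed saturation: values that do
+    // not fit in 32 bits clamp to INT32_MIN/INT32_MAX instead of wrapping.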
+ typeConverter.addTypeAttributeConversion(
+ [](IntegerType type, IntegerAttr attribute) -> Attribute {
+ const APInt value = attribute.getValue().truncSSat(32);
+ return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
+ value);
+ });
+ typeConverter.addTypeAttributeConversion(
+ [&typeConverter](ShapedType type,
+ DenseIntElementsAttr attr) -> Attribute {
+ const ShapedType newType =
+ cast<ShapedType>(typeConverter.convertType(type));
+ const auto oldElementType = cast<IntegerType>(type.getElementType());
+ const auto newElementType =
+ cast<IntegerType>(newType.getElementType());
+ if (oldElementType.getWidth() == newElementType.getWidth())
+ return attr;
+
+ DenseElementsAttr mapped =
+ attr.mapValues(newElementType, [&](const APInt &v) {
+ return v.truncSSat(newElementType.getWidth());
+ });
+ return mapped;
+ });
+
+ ConversionTarget target(*context);
+ target.addDynamicallyLegalDialect<tosa::TosaDialect>(
+ [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op->getResultTypes()) &&
+ typeConverter.isLegal(op->getOperandTypes());
+ });
+ if (convertFunctionBoundaries) {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [&typeConverter](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
+ const FunctionType funcType =
+ op->getParentOfType<func::FuncOp>().getFunctionType();
+ return llvm::equal(op.getOperandTypes(), funcType.getResults());
+ });
+ } else {
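+      // Function boundaries are left untouched in this mode; func.func and
+      // func.return are unconditionally legal.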
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return true; });
+ target.addDynamicallyLegalOp<func::ReturnOp>(
+ [](func::ReturnOp op) { return true; });
+ }
+
+ RewritePatternSet patterns(context);
+ if (convertFunctionBoundaries) {
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+ }
+ if (aggressiveRewrite) {
+ patterns.add<ConvertGenericOp>(typeConverter, context);
+ } else {
+ // Tensor
+ patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
+ // Data layout
+ patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
+ // Type conversion
+ patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
+      // Control flow
+ patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
+ }
+
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ac5d620..36e8940 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -70,6 +70,8 @@ namespace {
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c].
+// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?].
LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
ArrayRef<int64_t> lowerRankShape,
@@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
higherRankDim = higherRankShape[i + rankDiff];
lowerRankDim = lowerRankShape[i];
- if (lowerRankDim != 1 && higherRankDim != 1 &&
+ auto isStaticDimAndNotEqualToOne = [](int64_t dim) {
+ return dim != 1 && dim != ShapedType::kDynamic;
+ };
+
+ if (isStaticDimAndNotEqualToOne(lowerRankDim) &&
+ isStaticDimAndNotEqualToOne(higherRankDim) &&
lowerRankDim != higherRankDim)
return failure();
@@ -216,22 +223,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
bool mlir::tosa::hasUniqueConstantScatterIndices(
ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
- llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+ const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape();
const unsigned int indicesRank = indicesShape.size();
const unsigned int lastDimSize = indicesShape[indicesRank - 1];
// check each batch of indices from the flat indicesAttr values
// for duplicates
- auto const indicesValues = indicesAttr.getValues<int32_t>();
+ auto const indicesValues = indicesAttr.getValues<APInt>();
assert(
(indicesValues.size() % lastDimSize == 0) &&
"Constant indices data length should be a multiple of indicesShape[-1]");
- std::vector<uint64_t> indices(lastDimSize);
+ std::vector<APInt> indices(lastDimSize);
for (auto beg = indicesValues.begin(); beg < indicesValues.end();
beg += lastDimSize) {
std::copy(beg, beg + lastDimSize, indices.begin());
- std::sort(indices.begin(), indices.end());
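+    // APInt does not define operator<, so compare with signed less-than.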
+ std::sort(indices.begin(), indices.end(),
+ [](const APInt &a, const APInt &b) { return a.slt(b); });
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
// found duplicate values in indices in batch
return false;
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 02c86a0..c55b13d 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
maxAttr, quantBits, filterQuantDim,
isSigned, narrowRange));
}
+
+Type mlir::tosa::getStorageElementTypeFromQuantized(
+ quant::QuantizedType quantType) {
+ auto quantEty = quantType.getStorageType();
+  // The storage type doesn't capture the sign information, so explicitly
+  // create an unsigned type if needed.
+ if (!quantType.isSigned()) {
+ quantEty = IntegerType::get(quantEty.getContext(),
+ quantEty.getIntOrFloatBitWidth(),
+ IntegerType::Unsigned);
+ }
+ return quantEty;
+}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 062606e..86233b0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2062,6 +2062,10 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
DiagnosedSilenceableFailure result = applySequenceBlock(
callee.getBody().front(), getFailurePropagationMode(), state, results);
+
+ if (!result.succeeded())
+ return result;
+
mappings.clear();
detail::prepareValueMappings(
mappings, callee.getBody().front().getTerminator()->getOperands(), state);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 8859541..24b0487 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1495,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
- return llvm::to_vector(llvm::map_range(
- range, [](transform::MappedValue value) { return cast<T>(value); }));
+ return llvm::map_to_vector(range, llvm::CastTo<T>);
}
void transform::detail::setApplyToOneResults(
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index f727118..2bd6205 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -156,7 +156,7 @@ DiagnosedSilenceableFailure
transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- std::optional<size_t> selectedRegionIdx;
+ std::optional<int64_t> selectedRegionIdx;
if (auto selectedRegionAttr = getSelectedRegionAttr())
selectedRegionIdx = selectedRegionAttr->getSExtValue();
@@ -232,7 +232,7 @@ LogicalResult transform::tune::AlternativesOp::verify() {
}
if (auto selectedRegionAttr = getSelectedRegionAttr()) {
- size_t regionIdx = selectedRegionAttr->getSExtValue();
+ int64_t regionIdx = selectedRegionAttr->getSExtValue();
if (regionIdx < 0 || regionIdx >= getNumRegions())
return emitOpError()
<< "'selected_region' attribute specifies region at index "
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index a26edac..2986f4c 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -106,14 +106,12 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
AffineMap bound = [&] {
if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
- lowerBound[0] == upperBound[0]) {
+ lowerBound[0] == upperBound[0])
return lowerBound[0];
- }
- if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
+ if (boundType == BoundType::LB && !invalidBound(lowerBound))
return lowerBound[0];
- } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
+ if (boundType == BoundType::UB && !invalidBound(upperBound))
return upperBound[0];
- }
return AffineMap{};
}();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daef0ba..2789f63 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6066,19 +6066,21 @@ LogicalResult ScatterOp::verify() {
VectorType indVType = getIndexVectorType();
VectorType maskVType = getMaskVectorType();
VectorType valueVType = getVectorType();
- MemRefType memType = getMemRefType();
+ ShapedType baseType = getBaseType();
- if (valueVType.getElementType() != memType.getElementType())
+ if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
+ return emitOpError("requires base to be a memref or ranked tensor type");
+
+ if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getOffsets()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (llvm::size(getOffsets()) != baseType.getRank())
+ return emitOpError("requires ") << baseType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
-
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
@@ -6241,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
LogicalResult ShapeCastOp::verify() {
VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 546099c..352f477 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
using namespace mlir;
using namespace mlir::bufferization;
@@ -126,6 +127,54 @@ struct TransferWriteOpInterface
}
};
+/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
+/// operates on a memref.
+struct ScatterOpInterface
+ : public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
+ vector::ScatterOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ return true;
+ }
+
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
+ "only tensor types expected");
+ auto scatterOp = cast<vector::ScatterOp>(op);
+ if (&opOperand != &scatterOp.getBaseMutable())
+ return {};
+ return {{scatterOp.getResult(), BufferRelation::Equivalent}};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
+ auto scatterOp = cast<vector::ScatterOp>(op);
+ assert(isa<TensorType>(scatterOp.getBaseType()) &&
+ "only tensor types expected");
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, scatterOp.getBase(), options, state);
+ if (failed(buffer))
+ return failure();
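+    // Create a memref-based scatter (no tensor result) and use the base
+    // buffer as the bufferized replacement value.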
+ vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
+ /*resultType=*/nullptr, *buffer,
+ scatterOp.getOffsets(), scatterOp.getIndices(),
+ scatterOp.getMask(), scatterOp.getValueToStore());
+ replaceOpWithBufferizedValues(rewriter, op, *buffer);
+ return success();
+ }
+};
+
/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
struct GatherOpInterface
@@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
+ ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index 258f2cb..1af5523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -111,7 +111,7 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
if (!isValidKind(isInt, scanOp.getKind()))
return failure();
- VectorType resType = VectorType::get(destShape, elType);
+ VectorType resType = destType;
Value result = arith::ConstantOp::create(rewriter, loc, resType,
rewriter.getZeroAttr(resType));
int64_t reductionDim = scanOp.getReductionDim();
@@ -121,8 +121,18 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
int64_t initialValueRank = initialValueType.getRank();
SmallVector<int64_t> reductionShape(destShape);
+ SmallVector<bool> reductionScalableDims(destType.getScalableDims());
+
+ if (reductionScalableDims[reductionDim])
+ return rewriter.notifyMatchFailure(
+ scanOp, "Trying to reduce scalable dimension - not yet supported!");
+
+ // The reduction dimension, after reducing, becomes 1. It's a fixed-width
+ // dimension - no need to touch the scalability flag.
reductionShape[reductionDim] = 1;
- VectorType reductionType = VectorType::get(reductionShape, elType);
+ VectorType reductionType =
+ VectorType::get(reductionShape, elType, reductionScalableDims);
+
SmallVector<int64_t> offsets(destRank, 0);
SmallVector<int64_t> strides(destRank, 1);
SmallVector<int64_t> sizes(destShape);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 726da1e..ad16b80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1)
return failure();
+ if (!isa<VectorType>(op->getResult(0).getType()))
+ return failure();
auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
if (!bcastOp)
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae098..462bd8c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,286 @@ private:
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) as:
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+/// Given a create_mask operation:
+///   %0 = vector.create_mask %c6, %c10 : vector<8x16xi1>
+///   (masks the first 6x10 elements)
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+/// %false = arith.constant dense<false> : vector<8x16xi1>
+///
+/// Slice [0,0]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [0,8]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,0]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,8]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+ UnrollCreateMaskPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, createMaskOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = createMaskOp.getVectorType();
+ SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+ Location loc = createMaskOp.getLoc();
+
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ VectorType targetVectorType =
+ VectorType::get(*targetShape, rewriter.getI1Type());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+
+ // In each dimension (d), each unrolled vector computes its mask size as:
+ // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalSize, *targetShape)) {
+ SmallVector<Value> unrolledOperands;
+
+ for (auto [i, originalMaskOperand] :
+ llvm::enumerate(createMaskOp.getOperands())) {
+ Value offsetVal =
+ arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
+ Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
+ loc, originalMaskOperand, offsetVal);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value unrolledDimSize =
+ arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
+ Value nonNegative =
+ rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
+ Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
+ loc, nonNegative, unrolledDimSize);
+ unrolledOperands.push_back(unrolledOperand);
+ }
+
+ auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
+ loc, targetVectorType, unrolledOperands);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, unrolledMask, result, offsets, strides);
+ }
+ rewriter.replaceOp(createMaskOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
+/// Checks whether extractShape is a contiguous slice of shape.
+/// For extractShape to be contiguous in shape:
+/// 1) All but the leading dimension of extractShape and shape must match
+///    exactly.
+/// 2) The total number of elements in shape must be evenly divisible by
+///    the total number of elements in extractShape.
+/// Examples:
+/// isContiguous([4, 4], [8, 4]) == true
+/// isContiguous([2, 4], [8, 4]) == true
+/// isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+/// isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> extractShape,
+ ArrayRef<int64_t> shape) {
+
+ if (extractShape.size() > shape.size())
+ return false;
+
+ while (!extractShape.empty() && extractShape.front() == 1) {
+ extractShape = extractShape.drop_front();
+ }
+
+ while (!shape.empty() && shape.front() == 1) {
+ shape = shape.drop_front();
+ }
+
+ size_t rankDiff = shape.size() - extractShape.size();
+ if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
+ return false;
+
+ int64_t extractElements = ShapedType::getNumElements(extractShape);
+ int64_t shapeElements = ShapedType::getNumElements(shape);
+ return shapeElements % extractElements == 0;
+}
+
+/// Determines what shape to use with `vector.extract_strided_slice` to extract
+/// a contiguous memory region from a source vector. The extraction must be
+/// contiguous and contain exactly the specified number of elements. If such an
+/// extraction shape cannot be determined, returns std::nullopt.
+/// EXAMPLE 1:
+/// sourceShape = [16], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
+/// remaining = 8/8 = 1
+/// Result: [8]
+///
+/// EXAMPLE 2:
+/// sourceShape = [4, 4], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
+/// remaining = 8/4 = 2
+/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+/// remaining = 2/2 = 1
+/// Result: [2, 4]
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+ int64_t targetElements) {
+ SmallVector<int64_t> extractShape;
+ int64_t remainingElements = targetElements;
+
+ // Build extract shape from innermost dimension outward to ensure contiguity.
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+ int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+ extractShape.insert(extractShape.begin(), takeFromDim);
+
+ if (remainingElements % takeFromDim != 0)
+ return std::nullopt; // Not evenly divisible.
+ remainingElements /= takeFromDim;
+ }
+
+ // Fill remaining dimensions with 1.
+ while (extractShape.size() < sourceShape.size())
+ extractShape.insert(extractShape.begin(), 1);
+
+ if (ShapedType::getNumElements(extractShape) != targetElements)
+ return std::nullopt;
+
+ return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position.
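+// For example, with resultShape = [4, 4] and sourceShape = [8, 2], result
+// offsets [2, 0] linearize to 8, which delinearizes to source offsets [4, 0].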
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> resultShape) {
+ // Convert result offsets to linear position.
+ int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
+ // Convert linear position to source offsets.
+ return delinearize(linearIndex, computeStrides(sourceShape));
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (rather than decomposing into scalars). In these
+/// cases, the unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice.
+///
+/// Example:
+/// Given a shape cast operation:
+/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+ UnrollShapeCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, shapeCastOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType sourceType = shapeCastOp.getSourceVectorType();
+ VectorType resultType = shapeCastOp.getResultVectorType();
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+
+ if (!isContiguous(*targetShape, resultShape))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "Only supports cases where target shape is "
+ "contiguous in result vector shape");
+
+ int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+ // Calculate the shape to extract from source.
+ std::optional<SmallVector<int64_t>> extractShape =
+ calculateSourceExtractShape(sourceShape, targetElements);
+ if (!extractShape)
+ return rewriter.notifyMatchFailure(
+ shapeCastOp,
+ "cannot extract target number of elements contiguously from source");
+
+ Location loc = shapeCastOp.getLoc();
+
+ // Create result vector initialized to zero.
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+
+ VectorType targetType =
+ VectorType::get(*targetShape, sourceType.getElementType());
+
+ SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+ SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ SmallVector<int64_t> sourceOffsets =
+ calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
+ Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+ extractStrides);
+ Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, targetType, sourceChunk);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, targetChunk, result, resultOffsets, insertStrides);
+ }
+
+ rewriter.replaceOp(shapeCastOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
+ UnrollCreateMaskPattern>(patterns.getContext(), options,
+ benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c809c502..c307fb4 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -322,46 +322,61 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
std::optional<Value> padValue,
bool useInBoundsInsteadOfMasking,
ArrayRef<bool> inputScalableVecDims) {
- assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+ VectorType vecToReadTy = VectorType::get(
+ inputVectorSizes, cast<ShapedType>(source.getType()).getElementType(),
+ inputScalableVecDims);
+
+ return createReadOrMaskedRead(builder, loc, source, vecToReadTy, padValue,
+ useInBoundsInsteadOfMasking);
+}
+
+Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
+ Value source,
+ const VectorType &vecToReadTy,
+ std::optional<Value> padValue,
+ bool useInBoundsInsteadOfMasking) {
+ assert(!llvm::is_contained(vecToReadTy.getScalableDims(),
+ ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
- assert(sourceShape.size() == inputVectorSizes.size() &&
+
+ int64_t vecToReadRank = vecToReadTy.getRank();
+ auto vecToReadShape = vecToReadTy.getShape();
+
+ assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
"expected same ranks.");
- auto vectorType =
- VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
- inputScalableVecDims);
assert((!padValue.has_value() ||
padValue.value().getType() == sourceShapedType.getElementType()) &&
"expected same pad element type to match source element type");
- int64_t readRank = inputVectorSizes.size();
+
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
- SmallVector<bool> inBoundsVal(readRank, true);
+ SmallVector<bool> inBoundsVal(vecToReadRank, true);
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the read indices.
- for (unsigned i = 0; i < readRank; i++)
- inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+ for (unsigned i = 0; i < vecToReadRank; i++)
+ inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
ShapedType::isStatic(sourceShape[i]);
}
auto transferReadOp = vector::TransferReadOp::create(
builder, loc,
- /*vectorType=*/vectorType,
+ /*vectorType=*/vecToReadTy,
/*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero),
+ /*indices=*/SmallVector<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
- if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+ if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+ useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
isa<MemRefType>(source.getType())
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
- inputScalableVecDims);
+ auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627..cb1e9d0 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..f4c9f8a
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+ X86VectorTransformOps.cpp
+
+ DEPENDS
+ MLIRX86VectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRX86VectorDialect
+ MLIRX86VectorTransforms
+ )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000..95db208
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,64 @@
+//===- X86VectorTransformOps.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToFMAPatterns(patterns);
+}
+
+void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ X86VectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ X86VectorTransformDialectExtension)
+
+ X86VectorTransformDialectExtension() {
+ declareGeneratedDialect<x86vector::X86VectorDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266a..2cab50f 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,11 +1,14 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
+ VectorContractToFMA.cpp
+ VectorContractToPackedTypeDotProduct.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRX86VectorDialect
MLIRIR
+ MLIRLinalgDialect
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
new file mode 100644
index 0000000..f3af5ca
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
@@ -0,0 +1,143 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <16xf32>
+// vector.fma vector<16xf32>
+// ```
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only F32 lowering is supported.");
+
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return dim != 1; });
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return dim != 1; });
+
+ if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects unit dimensions for either LHS or RHS shape.");
+
+ if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects exactly one non-unit A/B dimension on either LHS or RHS.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator is not a vector type.");
+
+ if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Accumulator should be F32 type.");
+
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects exactly one non-unit accumulator dimension.");
+
+ // Lowers vector.contract into a broadcast+FMA sequence.
+ auto loc = contractOp.getLoc();
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ vector::FMAOp fma;
+
+ // Broadcast the unit-dimension LHS or RHS to match the vector length of the
+ // corresponding non-unit dimension on the other operand. For example,
+ // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we
+    // broadcast the LHS to vector<16xf32>. In the opposite case (non-unit
+ // dimension on the LHS), we broadcast the RHS instead.
+ if (nonUnitDimRhs.size() > 0) {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), castLhs);
+ fma =
+ vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
+ } else {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto broadcastRhs = vector::BroadcastOp::create(
+ rewriter, loc, castLhs.getResult().getType(), castRhs);
+ fma =
+ vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
+ }
+
+ auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+ rewriter.replaceOp(contractOp, castFma);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractToFMAPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToFMA>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
new file mode 100644
index 0000000..1e64811
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -0,0 +1,301 @@
+//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+static FailureOr<SmallVector<mlir::utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+ if (!map.isProjectedPermutation())
+ return failure();
+ SmallVector<mlir::utils::IteratorType> iterators(
+ map.getNumDims(), mlir::utils::IteratorType::reduction);
+ for (auto expr : map.getResults())
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
+ iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
+ return iterators;
+}
+
+// Returns true if the operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
+ std::optional<unsigned> blockingFactor) {
+ // Narrow down type operations - VNNI only applies to contractions.
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return false;
+
+ auto matA = op->getOperand(0);
+ auto matB = op->getOperand(1);
+ auto typeA = dyn_cast<ShapedType>(matA.getType());
+ auto typeB = dyn_cast<ShapedType>(matB.getType());
+ unsigned rankA = typeA.getRank();
+ unsigned rankB = typeB.getRank();
+ // VNNI format requires at least 1 parallel and 2 reduction dimensions.
+ if (rankA < 3 || rankB < 3)
+ return false;
+
+ // At least two reduction dimensions are expected:
+ // one for the VNNI factor and one for the K dimension
+ if (dims->k.size() < 2)
+ return false;
+
+ // Validate affine maps - VNNI computation should be defined by the two
+ // innermost reduction iterators.
+ // The input matrix dimensions layout must match the following:
+ // - matrix A - [...][K/vnniFactor][vnniFactor]
+ // - matrix B - [...][K/vnniFactor][N][vnniFactor]
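+  // For instance, with vnniFactor = 2: A = vector<1x8x2xbf16> and
+  // B = vector<8x16x2xbf16> encode K = 16 and N = 16.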
+ auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
+ if (failed(maybeIters))
+ return false;
+ SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] !=
+ mlir::utils::IteratorType::reduction)
+ return false;
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] !=
+ mlir::utils::IteratorType::reduction)
+ return false;
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
+ if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
+ mlir::utils::IteratorType::parallel)
+ return false;
+
+  // VNNI factor must be:
+  //  - the innermost inputs' dimension
+  //  - statically known
+  //  - a multiple of 2 and, when specified, equal to the blocking factor
+ auto vnniDimSize = typeB.getShape().back();
+ if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
+ vnniDimSize % 2 != 0)
+ return false;
+ if (typeA.getShape().back() != vnniDimSize)
+ return false;
+ if (blockingFactor && vnniDimSize != *blockingFactor)
+ return false;
+
+ // The split reduction dimension size should also match.
+ if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
+ return false;
+
+ return true;
+}
+
+// Implements packed type outer product contraction as a sequence
+// of broadcast and packed dot-product operations.
+//
+// For example - for BF16 inputs with F32 accumulation:
+// ```
+// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <32xbf16>
+// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
+// ```
+struct VectorContractToPackedTypeDotProduct
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16() &&
+ !lhsTy.getElementType().isSignlessInteger(8))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only BF16/Int8 lowering is supported.");
+
+ unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
+ if (!isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(), blockingFactor))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Input matrices not in VNNI format.");
+
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return dim != 1; });
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return dim != 1; });
+
+ if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape, apart from VNNI.");
+
+ if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects exactly one non-unit A/B dimension on either LHS or RHS.");
+
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects a vector accumulator type.");
+
+ if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
+ (lhsTy.getElementType().isSignlessInteger(8) &&
+ !accTy.getElementType().isSignlessInteger(32)))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only F32 (for BF16) or Int32 (for "
+                                         "Int8) accumulation is supported.");
+
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+ if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects exactly one non-unit accumulator dimension.");
+
+ // Non-unit dimensions should match the vector length of BF16 or Int8
+ // dot-product.
+ unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
+ : nonUnitDimRhs.front();
+ if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
+ nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 dot-product operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8/16.");
+
+ if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
+ nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Int8 dot-product operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8.");
+
+ auto loc = contractOp.getLoc();
+ auto castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ Value dp;
+
+ // Broadcast the unit-dimension LHS or RHS to match the vector length of the
+ // corresponding non-unit dimension on the other operand. For example,
+ // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
+ // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit
+ // dimension on the LHS), we broadcast the RHS instead.
+ if ((nonUnitDimRhs.size() - 1) > 0) {
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
+ rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto bitcastLhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castLhs);
+ auto broadcastLhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
+ bitcastLhs);
+ auto bitcastLhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
+ } else {
+ auto castLhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
+ lhsTy.getElementType()),
+ contractOp.getLhs());
+ auto castRhs = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
+ contractOp.getRhs());
+ auto bitcastRhs = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
+ castRhs);
+ auto broadcastRhs = vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
+ bitcastRhs);
+ auto bitcastRhsPkType = vector::BitCastOp::create(
+ rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
+
+ if (lhsTy.getElementType().isBF16()) {
+ dp = x86vector::DotBF16Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
+
+ if (lhsTy.getElementType().isSignlessInteger(8)) {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
+ }
+
+ if (!dp)
+ return failure();
+
+ auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
+ rewriter.replaceOp(contractOp, castDp);
+ return success();
+ }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractToPackedTypeDotProductPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index fb5d1e7..1a19ab5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,7 +8,6 @@
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
@@ -61,7 +60,7 @@ genCoordinates(OpBuilder &builder, Location loc,
// Get the offset of `subShape` within a distribution unit.
SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
- return builder.createOrFold<index::MulOp>(
+ return builder.createOrFold<arith::MulIOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
@@ -84,7 +83,7 @@ genCoordinates(OpBuilder &builder, Location loc,
// Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
- return builder.createOrFold<index::RemUOp>(
+ return builder.createOrFold<arith::RemUIOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
@@ -343,7 +342,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
/// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
/// this dimension)
result[dimIdx] =
- builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+ builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
/// Update remaining for the next dimension by removing what we've already
/// processed. Division tells us "how many complete groups of this dimension
@@ -352,7 +351,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
/// no next dimension to process
if (i < order.size() - 1) {
remaining =
- builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+ builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
}
}
return result;
@@ -391,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::SliceAttr>(other))
+ return false;
+
+ return *this == dyn_cast<xegpu::LayoutAttr>(other);
+}
+
+// Set sg_data, inst_data and lane_data to 1 for the given unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ auto sgDataOpt = getSgData();
+ auto instDataOpt = getInstData();
+ auto laneDataOpt = getLaneData();
+
+ SmallVector<int32_t> sgData;
+ SmallVector<int32_t> instData;
+ SmallVector<int32_t> laneData;
+
+ if (sgDataOpt) {
+ sgData = llvm::to_vector(sgDataOpt.asArrayRef());
+ }
+ if (instDataOpt) {
+ instData = llvm::to_vector(instDataOpt.asArrayRef());
+ }
+ if (laneDataOpt) {
+ laneData = llvm::to_vector(laneDataOpt.asArrayRef());
+ }
+
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgData.size()))
+ sgData[dim] = 1;
+ if (dim < static_cast<int64_t>(instData.size()))
+ instData[dim] = 1;
+ if (dim < static_cast<int64_t>(laneData.size()))
+ laneData[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(), getSgLayout(),
+ sgData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgData),
+ instData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), instData),
+ getLaneLayout(),
+ laneData.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneData),
+ getOrder());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ auto sgLayoutOpt = getSgLayout();
+ auto laneLayoutOpt = getLaneLayout();
+
+ SmallVector<int32_t> sgLayout;
+ SmallVector<int32_t> laneLayout;
+
+ if (sgLayoutOpt) {
+ sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
+ }
+ if (laneLayoutOpt) {
+ laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
+ }
+
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(sgLayout.size()))
+ sgLayout[dim] = 1;
+ if (dim < static_cast<int64_t>(laneLayout.size()))
+ laneLayout[dim] = 1;
+ }
+
+ return LayoutAttr::get(
+ getContext(),
+ sgLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgLayout),
+ getSgData(), getInstData(),
+ laneLayout.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneLayout),
+ getLaneData(), getOrder());
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@@ -511,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+ if (dyn_cast<xegpu::LayoutAttr>(other))
+ return false;
+
+ auto flattenedThis = flatten();
+ auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+
+ return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+ (flattenedThis.getDims() == flattenedOther.getDims()));
+}
+
+// Helper function to adjust unit dimensions from sliced space to parent space
+static SetVector<int64_t>
+adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
+ ArrayRef<int64_t> sliceDims) {
+ // Reconstruct parent's non-sliced dimensions
+
+ int64_t parentRank = sliceDims.size() + unitDims.size();
+ llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
+ sliceDims.end());
+ SmallVector<int64_t> nonSlicedDims;
+ for (int64_t i = 0; i < parentRank; ++i) {
+ if (!slicedDimsSet.contains(i))
+ nonSlicedDims.push_back(i);
+ }
+
+ // Map unit dims from sliced space to parent space
+ SetVector<int64_t> adjustUnitDims;
+ for (auto dim : unitDims) {
+ if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
+ adjustUnitDims.insert(nonSlicedDims[dim]);
+ }
+ }
+
+ return adjustUnitDims;
+}
+
+// Set sg_data, inst_data and lane_data to 1 for the given unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
+ attr.getDims());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims.
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+ SliceAttr attr = flatten();
+ ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+
+ SetVector<int64_t> adjustUnitDims =
+ adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+
+ return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
+ attr.getDims());
+}
+
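For reference, a minimal standalone sketch (plain C++17, no MLIR dependency; all names hypothetical) of the remapping performed by adjustUnitDimsWithSliceDims above: unit dims given in the sliced (rank-reduced) space are translated to the parent layout's dimension numbering before being forwarded to the parent's setUnitDimData/setUnitDimLayout.

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

static std::vector<int64_t>
mapUnitDimsToParent(const std::vector<int64_t> &unitDims,
                    const std::vector<int64_t> &sliceDims) {
  // Parent rank is reconstructed as the number of sliced dims plus the number
  // of sliced-space (unit) dims, mirroring the helper above.
  int64_t parentRank = static_cast<int64_t>(sliceDims.size() + unitDims.size());
  std::set<int64_t> sliced(sliceDims.begin(), sliceDims.end());
  // Parent dims that survive the slice, in order; position i in the sliced
  // space corresponds to nonSliced[i] in the parent space.
  std::vector<int64_t> nonSliced;
  for (int64_t d = 0; d < parentRank; ++d)
    if (!sliced.count(d))
      nonSliced.push_back(d);
  std::vector<int64_t> mapped;
  for (int64_t d : unitDims)
    if (d < static_cast<int64_t>(nonSliced.size()))
      mapped.push_back(nonSliced[d]);
  return mapped;
}

int main() {
  // Parent layout is 3-D and dim 0 was sliced away; unit dims {0, 1} of the
  // 2-D sliced view map to parent dims {1, 2}.
  for (int64_t d : mapUnitDimsToParent({0, 1}, {0}))
    std::cout << d << " "; // prints: 1 2
  std::cout << "\n";
  return 0;
}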
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4dd10be..91ba07a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -465,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -480,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
- l2_hint, l3_hint);
+ l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult PrefetchNdOp::verify() {
@@ -519,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
return build(builder, state, retType, tensorDesc, ValueRange(),
DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
- l3_hint);
+ l3_hint, /*anchor_layout=*/nullptr);
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
@@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -535,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- packed, transpose, l1_hint, l2_hint, l3_hint);
+ packed, transpose, l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/layout);
}
LogicalResult LoadNdOp::verify() {
@@ -638,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
xegpu::CachePolicyAttr l3_hint) {
return build(builder, state, value, tensorDesc, ValueRange(),
- DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+ DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
+ /*anchor_layout=*/nullptr);
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint) {
+ xegpu::CachePolicyAttr l3_hint,
+ xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -653,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
- l1_hint, l2_hint, l3_hint);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult StoreNdOp::verify() {
@@ -826,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
- IntegerAttr{});
+ IntegerAttr{}, /*anchor_layout=*/nullptr);
}
//===----------------------------------------------------------------------===//
@@ -876,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
- l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
+ l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -892,7 +897,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
- l2_hint, l3_hint, /*layout=*/nullptr);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -901,7 +906,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint,
- xegpu::LayoutAttr layout) {
+ DistributeLayoutAttr layout) {
auto loc = source.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
@@ -960,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
- l2_hint, l3_hint, /*layout=*/nullptr);
+ l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -978,14 +983,14 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
- l3_hint, /*layout=*/nullptr);
+ l3_hint, /*anchor_layout=*/nullptr);
}
void StoreScatterOp::build(
OpBuilder &builder, OperationState &state, Value value, Value dest,
ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
- xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
+ xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
auto loc = dest.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
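A hedged caller-side sketch of the extended builders (the helper name, operands, and types here are illustrative, not part of this patch): the trailing DistributeLayoutAttr is threaded through as the op's anchor_layout, and passing nullptr keeps the previous behaviour.

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

using namespace mlir;

// Hypothetical helper; resTy/tdesc/offsets/layout come from the calling pass.
static Value buildAnchoredLoad(OpBuilder &builder, Location loc, Type resTy,
                               Value tdesc, ArrayRef<OpFoldResult> offsets,
                               xegpu::DistributeLayoutAttr layout) {
  auto cached = xegpu::CachePolicyAttr::get(builder.getContext(),
                                            xegpu::CachePolicy::CACHED);
  // Uses the LoadNdOp builder overload extended above with a layout argument.
  return xegpu::LoadNdOp::create(builder, loc, resTy, tdesc, offsets,
                                 /*packed=*/UnitAttr(),
                                 /*transpose=*/DenseI64ArrayAttr(),
                                 /*l1_hint=*/cached, /*l2_hint=*/cached,
                                 /*l3_hint=*/cached,
                                 /*anchor_layout=*/layout)
      .getResult();
}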
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba0..e6009d5 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,12 +7,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include <optional>
+#include "llvm/Support/DebugLog.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
using namespace mlir;
using namespace mlir::transform;
@@ -76,6 +81,45 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
return DiagnosedSilenceableFailure::success();
}
+/// Find producer operation of type T for the given value.
+/// It's assumed that producer ops are chained through their first operand.
+/// The producer chain is traced through loop block arguments (init values).
+template <typename T>
+static std::optional<T> findProducerOfType(Value val) {
+ Value currentValue = val;
+ if (!currentValue.getDefiningOp()) {
+ // Value may be a block argument initialized outside a loop.
+ if (val.getNumUses() == 0) {
+ LDBG() << "Failed to find producer op, value has no uses.";
+ return std::nullopt;
+ }
+ auto userOp = val.getUsers().begin();
+ auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
+ if (!parentLoop) {
+ LDBG() << "Failed to find producer op, not in a loop.";
+ return std::nullopt;
+ }
+ int64_t iterArgIdx;
+ if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
+ auto numInductionVars = parentLoop.getLoopInductionVars()->size();
+ iterArgIdx = iterArg.getArgNumber() - numInductionVars;
+ currentValue = parentLoop.getInits()[iterArgIdx];
+ } else {
+ LDBG() << "Failed to find producer op, value not in init values.";
+ return std::nullopt;
+ }
+ }
+ Operation *producerOp = currentValue.getDefiningOp();
+
+ if (auto matchingOp = dyn_cast<T>(producerOp))
+ return matchingOp;
+
+ if (producerOp->getNumOperands() == 0)
+ return std::nullopt;
+
+ return findProducerOfType<T>(producerOp->getOperand(0));
+}
+
/// Create a layout attribute from the given parameters.
static xegpu::LayoutAttr
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -90,10 +134,41 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
/*order=*/nullptr);
}
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
+ TransformOpInterface transformOp,
+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ xegpu::LayoutAttr &layoutAttr) {
+ SmallVector<int32_t> sgLayout, sgData, instData;
+ auto status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
- xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+ xegpu::CreateNdDescOp descOp,
+ xegpu::DistributeLayoutAttr layout) {
assert(descOp.getMixedOffsets().size() == 0 &&
"create desc op with offsets is not supported");
auto oldTensorDesc = descOp.getType();
@@ -111,11 +186,35 @@ setDescLayout(transform::TransformRewriter &rewriter,
return newDescOp;
}
+DiagnosedSilenceableFailure
+transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues)) {
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ }
+
+ auto maybeDescOp =
+ findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
+ if (!maybeDescOp) {
+ return emitSilenceableFailure(getLoc())
+ << "Could not find a matching descriptor op when walking the "
+ "producer chain of the first operand.";
+ }
+
+ results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
void transform::SetDescLayoutOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedSgLayout,
ArrayRef<OpFoldResult> mixedSgData,
- ArrayRef<OpFoldResult> mixedInstData) {
+ ArrayRef<OpFoldResult> mixedInstData,
+ ArrayRef<int64_t> sliceDims) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -128,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
/*inst_data=*/dynamicInstData,
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
- /*static_inst_data=*/staticInstData);
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims);
}
DiagnosedSilenceableFailure
@@ -142,25 +242,20 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
Operation *target = *targetOps.begin();
- SmallVector<int32_t> sgLayout;
- DiagnosedSilenceableFailure status =
- convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
- if (!status.succeeded())
- return status;
-
- SmallVector<int32_t> sgData;
- status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
if (!status.succeeded())
return status;
- SmallVector<int32_t> instData;
- status =
- convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
- if (!status.succeeded())
- return status;
- auto maybeInstData = instData.empty()
- ? std::nullopt
- : std::optional<ArrayRef<int32_t>>(instData);
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
@@ -173,9 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
// Set layout attr in desc op's return type. Replaces old desc op.
- auto layoutAttr =
- createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
- auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+ auto newdescOp = setDescLayout(rewriter, descOp, layout);
// Map result handles.
results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
@@ -193,6 +286,383 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}
+void transform::SetOpLayoutAttrOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
+ bool result) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*index=*/index,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*slice_dims=*/sliceDims,
+ /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ bool resultTarget = getResult();
+
+ int64_t index = getIndex();
+ if (resultTarget && index >= target->getNumResults()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op results";
+ }
+ if (!resultTarget && index >= target->getNumOperands()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op operands";
+ }
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::DistributeLayoutAttr layout = layoutAttr;
+ auto sliceDims = getSliceDims();
+ if (sliceDims.size() > 0) {
+ // Wrap layoutAttr in a slice attribute.
+ layout = xegpu::SliceAttr::get(
+ getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
+ }
+
+ // Set layout attribute for the op result or operand
+ if (resultTarget)
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
+ else
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::SetGPULaunchThreadsOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedThreads) {
+ SmallVector<int64_t> staticThreads;
+ SmallVector<Value> dynamicThreads;
+ dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*threads=*/dynamicThreads,
+ /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+ if (!launchOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a gpu.launch op, but got: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ SmallVector<int32_t> threads;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+ if (!status.succeeded())
+ return status;
+
+ if (threads.size() != 3) {
+ return emitSilenceableFailure(getLoc())
+ << "Expected threads argument to consist of three values (got "
+ << threads.size() << ")";
+ }
+
+ rewriter.setInsertionPoint(launchOp);
+ auto createConstValue = [&](int value) {
+ return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+ };
+
+ // Replace threads in-place.
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+ launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+ launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getThreadsMutable(), effects);
+ modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ int64_t nbPrefetch = getStaticNbPrefetch();
+ if (getDynamicNbPrefetch()) {
+ // Get dynamic prefetch count from transform param or handle.
+ SmallVector<int32_t> dynamicNbPrefetch;
+ auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+ {getDynamicNbPrefetch()});
+ if (!status.succeeded())
+ return status;
+ if (dynamicNbPrefetch.size() != 1)
+ return emitDefiniteFailure()
+ << "requires exactly one value for dynamic_nb_prefetch";
+ nbPrefetch = dynamicNbPrefetch[0];
+ }
+ if (nbPrefetch <= 0)
+ return emitSilenceableFailure(getLoc())
+ << "nb_prefetch must be a positive integer.";
+
+ // Find load operation of the operand.
+ auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+ if (!maybeLoadOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+ auto loadOp = *maybeLoadOp;
+ if (loadOp.getMixedOffsets().size() == 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op must have offsets.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find the parent scf.for loop.
+ auto forOp = loadOp->getParentOfType<scf::ForOp>();
+ if (!forOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op is not contained in a scf.for loop.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find descriptor op.
+ auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+ if (!maybeDescOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ auto descOp = *maybeDescOp;
+ if (descOp.getMixedOffsets().size() > 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+ // Clone desc op outside the loop.
+ rewriter.setInsertionPoint(forOp);
+ auto newDescOp =
+ cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+ // Clone reduction loop to emit initial prefetches.
+ // Compute upper bound of the init loop: start + nbPrefetch * step.
+ auto nbPrefetchCst =
+ arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+ auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+ auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+ forOp.getLoc(), forOp.getLowerBound(), nbStep);
+ auto initForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ initUpBound, forOp.getStep());
+
+ auto ctx = rewriter.getContext();
+ auto readCacheHint =
+ xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+ // Modify loadOp mixedOffsets by replacing the for loop induction variable
+ // with the given value.
+ auto getPrefetchOffsets =
+ [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+ IRMapping mapping;
+ mapping.map(forOp.getInductionVar(), replacementVal);
+ SmallVector<Value> dynamicOffsets =
+ llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+ return mapping.lookupOrDefault(v);
+ }));
+ auto constOffsets = loadOp.getConstOffsets().value();
+ return getMixedValues(constOffsets, dynamicOffsets, ctx);
+ };
+
+ // Insert prefetch op in init loop.
+ // Replace induction var with the init loop induction var.
+ rewriter.setInsertionPointToStart(initForOp.getBody());
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(initForOp.getInductionVar()),
+ readCacheHint, readCacheHint, readCacheHint,
+ /*layout=*/nullptr);
+
+ // Insert prefetch op in main loop.
+ // Calculate prefetch offset after the init prefetches have been issued.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+ forOp.getInductionVar(), nbStep);
+ // Replace induction var with correct offset.
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(prefetchOffset), readCacheHint,
+ readCacheHint, readCacheHint, /*layout=*/nullptr);
+
+ // Unroll the init loop.
+ if (failed(loopUnrollFull(initForOp)))
+ return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+
+ results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
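To make the prefetch pipelining concrete, a standalone illustration (plain C++, hypothetical loop bounds) of the schedule InsertPrefetchOp sets up for lb=0, ub=256, step=32 and nb_prefetch=2: the cloned (and fully unrolled) init loop issues prefetches at offsets 0 and 32, and each main-loop iteration at iv prefetches the tile at iv + 2*step, so nb_prefetch tiles stay in flight ahead of the loads.

#include <cstdio>

int main() {
  const int lb = 0, ub = 256, step = 32, nb = 2;
  const int initUb = lb + nb * step; // upper bound of the init loop
  for (int iv = lb; iv < initUb; iv += step) // init loop (unrolled in the IR)
    std::printf("init prefetch at offset %d\n", iv);
  for (int iv = lb; iv < ub; iv += step) // main reduction loop
    std::printf("iter %3d: prefetch at offset %d\n", iv, iv + nb * step);
  return 0;
}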
+void transform::InsertPrefetchOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+void transform::ConvertLayoutOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedInputSgLayout,
+ ArrayRef<OpFoldResult> mixedInputSgData,
+ ArrayRef<OpFoldResult> mixedInputInstData,
+ ArrayRef<OpFoldResult> mixedTargetSgLayout,
+ ArrayRef<OpFoldResult> mixedTargetSgData,
+ ArrayRef<OpFoldResult> mixedTargetInstData) {
+ SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
+ staticInputInstData;
+ SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
+ dynamicInputInstData;
+ dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
+ staticInputSgLayout);
+ dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
+ staticInputSgData);
+ dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
+ staticInputInstData);
+ SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
+ staticTargetInstData;
+ SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
+ dynamicTargetInstData;
+ dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
+ staticTargetSgLayout);
+ dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
+ staticTargetSgData);
+ dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
+ staticTargetInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*input_sg_layout=*/dynamicInputSgLayout,
+ /*input_sg_data=*/dynamicInputSgData,
+ /*input_inst_data=*/dynamicInputInstData,
+ /*target_sg_layout=*/dynamicTargetSgLayout,
+ /*target_sg_data=*/dynamicTargetSgData,
+ /*target_inst_data=*/dynamicTargetInstData,
+ /*static_input_sg_layout=*/staticInputSgLayout,
+ /*static_input_sg_data=*/staticInputSgData,
+ /*static_input_inst_data=*/staticInputInstData,
+ /*static_target_sg_layout=*/staticTargetSgLayout,
+ /*static_target_sg_data=*/staticTargetSgData,
+ /*static_target_inst_data=*/staticTargetInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ // Construct layout attributes.
+ xegpu::LayoutAttr inputLayoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedInputSgLayout(),
+ getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ xegpu::LayoutAttr targetLayoutAttr = nullptr;
+ status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedTargetSgLayout(),
+ getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ // Find first user op to define insertion point for layout conversion.
+ if (value.use_empty())
+ return emitSilenceableFailure(getLoc())
+ << "Value has no users to insert layout conversion.";
+ Operation *userOp = *value.getUsers().begin();
+
+ // Emit convert_layout op.
+ rewriter.setInsertionPoint(userOp);
+ auto convLayoutOp =
+ xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
+ value, inputLayoutAttr, targetLayoutAttr);
+  // Redirect remaining uses of the value to the layout-converted result.
+ rewriter.replaceUsesWithIf(
+ value, convLayoutOp.getResult(), [&](OpOperand &use) {
+ return use.getOwner() != convLayoutOp.getOperation();
+ });
+
+ results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ConvertLayoutOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getInputSgLayoutMutable(), effects);
+ onlyReadsHandle(getInputSgDataMutable(), effects);
+ onlyReadsHandle(getInputInstDataMutable(), effects);
+ onlyReadsHandle(getTargetSgLayoutMutable(), effects);
+ onlyReadsHandle(getTargetSgDataMutable(), effects);
+ onlyReadsHandle(getTargetInstDataMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
index 4dc5ea4..ab41fe4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
- origLoadOp.getL3HintAttr());
+ origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
// Set the layout for the loadOp.
auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4e1a539..dc9eb96 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,6 +53,8 @@ using namespace mlir::dataflow;
namespace {
+enum class LayoutKind { Lane, InstData };
+
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -166,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
llvm_unreachable("Join should not be triggered by layout propagation.");
}
-/// Construct a new layout with the transposed lane layout and lane data.
+/// Construct a new layout with the transposed inst_data, or the transposed
+/// lane_layout and lane_data.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
if (!isAssigned())
return {};
@@ -186,12 +189,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneData;
SmallVector<int32_t> instData;
for (int64_t idx : permutation) {
- laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
- laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
- instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+ if (getLaneLayout().size()) {
+ laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+ laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ }
+ if (getInstData().size())
+ instData.push_back(static_cast<int32_t>(getInstData()[idx]));
}
- return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
- laneLayout, laneData));
+ xegpu::LayoutAttr layoutAttr;
+ if (getLaneLayout().size())
+ layoutAttr =
+ xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+ if (getInstData().size())
+ layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
+ return LayoutInfo(layoutAttr);
}
//===----------------------------------------------------------------------===//
@@ -213,15 +224,14 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
unsigned rank,
- const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData) {
+ const xegpu::uArch::uArch *uArch) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
return LayoutInfo(
- xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
+ xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(
- ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
@@ -236,7 +246,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData,
unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
@@ -247,16 +256,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{uArch->getSubgroupSize(), 1},
{1, packingFactor}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{1, uArch->getSubgroupSize()},
{1, packingFactor}));
}
@@ -264,7 +273,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData,
unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
@@ -275,18 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
+ return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
int subgroupSize = uArch->getSubgroupSize();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));
+ tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
}
return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));
+ tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -298,7 +306,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData, unsigned packingSize) {
+ unsigned packingSize) {
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
@@ -310,10 +318,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
{static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
1});
return LayoutInfo(
- xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
+ xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
+ return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -328,6 +336,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
+ LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -378,10 +387,14 @@ private:
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
+
public:
LayoutInfoPropagation(DataFlowSolver &solver,
- SymbolTableCollection &symbolTable)
- : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+ SymbolTableCollection &symbolTable,
+ LayoutKind layoutKind)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+ layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -464,43 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}
+bool LayoutInfoPropagation::hasParamsOfLayoutKind(
+ xegpu::DistributeLayoutAttr anchorLayout) {
+ if (anchorLayout == nullptr) {
+ return false;
+ }
+ if (layoutKind == LayoutKind::InstData) {
+ return !(anchorLayout.getEffectiveInstDataAsInt().empty());
+ } else if (layoutKind == LayoutKind::Lane) {
+ return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ anchorLayout.getEffectiveLaneDataAsInt().empty());
+ }
+ return false;
+}
+
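A hedged sketch of the two anchor forms this predicate distinguishes, using the LayoutAttr helpers already used elsewhere in this patch (the context pointer and the concrete numbers are assumptions for illustration only).

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative only: which anchors count as already assigned for each kind.
static void anchorLayoutExamples(mlir::MLIRContext *ctx) {
  // Carries only inst_data -> hasParamsOfLayoutKind is true for
  // LayoutKind::InstData and false for LayoutKind::Lane.
  llvm::SmallVector<int> instData{8, 16};
  auto instOnly = mlir::xegpu::LayoutAttr::get(ctx, instData);
  // Carries lane_layout/lane_data -> true for LayoutKind::Lane only.
  auto laneOnly = mlir::xegpu::LayoutAttr::get(ctx, {1, 16}, {1, 1});
  (void)instOnly;
  (void)laneOnly;
}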
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Here we assign the default layout to the tensor descriptor operand of
- // prefetch.
- auto tdescTy = prefetch.getTensorDescType();
-
- auto uArch = getUArch(getChipStr(prefetch).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
-
- auto blockWHC =
- uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
- if (!blockWHC)
- prefetch.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = xegpu::getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
- bCount);
- if (instWidth == -1)
- prefetch.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (tdescTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = xegpu::getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+
+ LayoutInfo prefetchLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ prefetchLayout = LayoutInfo(anchorLayout);
+ } else {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+
+ auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+ auto blockWHC =
+ uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+ if (!blockWHC)
+ prefetch.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
+ if (instWidth == -1)
prefetch.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
+ if (tdescTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ prefetchLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+ else
+ prefetchLayout = getDefaultSIMTLayoutInfo(
+ tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+
+ prefetch.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
}
- auto prefetchLayout = getDefaultSIMTLayoutInfo(
- tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
@@ -539,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!sourceTy) {
- broadcast.emitWarning("Expecting source type to be a vector type.");
+  // Skip layout propagation for non-vector source operands.
+ if (!sourceTy)
return;
- }
- // Only consider nD -> nD broadcast.
+  // Handle broadcast from a lower-rank to a higher-rank vector (e.g., 1D to 2D).
if (sourceTy.getRank() != resultTy.getRank()) {
- broadcast.emitWarning("Expecting source and result to have same rank.");
+ auto sourceDims = sourceTy.getShape();
+ auto resultDims = resultTy.getShape();
+ SmallVector<int64_t> bcastDims;
+ auto dimDiff = resultTy.getRank() - sourceTy.getRank();
+    // Add the missing leading dims.
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+    // For the remaining dims of resultTy, a sourceTy dim of size 1 is a
+    // broadcasted dim.
+ for (size_t i = 0; i < sourceDims.size(); i++)
+ if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
+ bcastDims.push_back(i + dimDiff);
+
+    // Create a slice layout for the source.
+ xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+ broadcast->getContext(),
+ cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+ DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
+
SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
- if (broadcastUnitDims.size() != 1) {
- broadcast.emitWarning("Expecting source type to be nD vector only with "
- "one broadcasted dimension.");
- return;
- }
- // Propagate the result layout to the source operand.
+ resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
+ .setUnitDimData(broadcastUnitDims);
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
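A standalone trace (plain C++; concrete shapes are assumed for illustration) of the bcastDims computation above for a rank-expanding broadcast vector<16xf32> -> vector<8x16xf32>: only the missing leading dim 0 is broadcast, so the source operand is assigned a SliceAttr of the result layout with dims = [0].

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> srcShape{16};    // vector<16xf32>
  std::vector<int64_t> dstShape{8, 16}; // vector<8x16xf32>
  int64_t dimDiff = static_cast<int64_t>(dstShape.size() - srcShape.size());
  std::vector<int64_t> bcastDims;
  // Missing leading dims are broadcast dims.
  for (int64_t i = 0; i < dimDiff; ++i)
    bcastDims.push_back(i);
  // A remaining source dim of size 1 mapping to a larger result dim is also a
  // broadcast dim.
  for (size_t i = 0; i < srcShape.size(); ++i)
    if (srcShape[i] == 1 && dstShape[i + dimDiff] != 1)
      bcastDims.push_back(static_cast<int64_t>(i) + dimDiff);
  for (int64_t d : bcastDims)
    std::printf("%lld ", static_cast<long long>(d)); // prints: 0
  std::printf("\n");
  return 0;
}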
@@ -600,55 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- VectorType aTy = dpas.getLhsType();
- VectorType bTy = dpas.getRhsType();
-
- auto uArch = getUArch(getChipStr(dpas).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
- xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
-
- const unsigned dataALen = aTy.getShape().front();
- auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
- const int maxALen =
- xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
- if (maxALen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
-
- const unsigned dataBLen = bTy.getShape().back();
- auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
- const int maxBLen =
- xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
- if (maxBLen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataA = {maxALen, subgroupSize};
- SmallVector<int> instDataB = {subgroupSize, maxBLen};
-
- propagateIfChanged(operands[0],
- operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
- aTy, 0, uArch, instDataA,
- uArchInstruction->getPackedFormatBitSizeA())));
- propagateIfChanged(operands[1],
- operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
- bTy, 1, uArch, instDataB,
- uArchInstruction->getPackedFormatBitSizeB())));
- if (operands.size() > 2) {
- VectorType cTy = dpas.getAccType();
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxCLen =
- xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1)
+
+ LayoutInfo dpasALayout;
+ LayoutInfo dpasBLayout;
+ LayoutInfo dpasCDLayout;
+
+ xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
+ if (hasParamsOfLayoutKind(anchorLayoutCD)) {
+ xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
+ xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
+ assert(hasParamsOfLayoutKind(anchorLayoutA) &&
+ "Expected anchor layout for DPAS A operand.");
+ assert(hasParamsOfLayoutKind(anchorLayoutB) &&
+ "Expected anchor layout for DPAS B operand.");
+ dpasALayout = LayoutInfo(anchorLayoutA);
+ dpasBLayout = LayoutInfo(anchorLayoutB);
+ dpasCDLayout = LayoutInfo(anchorLayoutCD);
+
+ } else {
+
+ VectorType aTy = dpas.getLhsType();
+ VectorType bTy = dpas.getRhsType();
+
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ const unsigned dataALen = aTy.getShape().front();
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+ if (maxALen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+
+ const unsigned dataBLen = bTy.getShape().back();
+ auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+
+ const int maxBLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+ if (maxBLen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataC = {maxALen, maxCLen};
- propagateIfChanged(operands[2],
- operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
- cTy, 2, uArch, instDataC,
- uArchInstruction->getPackedFormatBitSizeB())));
+ SmallVector<int> instDataA = {maxALen, subgroupSize};
+ SmallVector<int> instDataB = {subgroupSize, maxBLen};
+
+ if (layoutKind == LayoutKind::InstData) {
+ dpasALayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+ dpasBLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ } else {
+ dpasALayout = getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+ dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
+
+ if (operands.size() > 2) {
+ VectorType cTy = dpas.getAccType();
+ if (layoutKind == LayoutKind::InstData) {
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxCLen = xegpu::getLargestDivisor(
+ dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataC = {maxALen, maxCLen};
+ dpasCDLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+ } else
+ dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+ dpas.setLayoutCdAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+ }
+ dpas.setLayoutAAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
+ dpas.setLayoutBAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+ }
+
+ propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+ propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
+ if (operands.size() > 2) {
+ propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
}
@@ -657,37 +756,50 @@ void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- auto uArch = getUArch(getChipStr(store).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
- VectorType dataTy = store.getValueType();
- auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
- store.getValueType().getElementType());
- if (!blockWHC)
- store.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = xegpu::getLargestDivisor(
- static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
- bCount);
- if (instWidth == -1)
- store.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (dataTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = xegpu::getLargestDivisor(
- static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+ LayoutInfo storeLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ storeLayout = LayoutInfo(anchorLayout);
+ } else {
+ auto uArch = getUArch(getChipStr(store).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
+ VectorType dataTy = store.getValueType();
+ auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
+ store.getValueType().getElementType());
+ if (!blockWHC)
+ store.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
+ if (instWidth == -1)
store.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
+ if (dataTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ storeLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+ else
+ storeLayout =
+ getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
+ store.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
}
- LayoutInfo storeLayout =
- getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
- uArchInstruction->getPackedFormatBitSize());
+ // Propagate the layout to the value operand.
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -698,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp(
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo valueLayout = results[0]->getValue();
- // Need the layout of the value to propagate to the tensor descriptor.
- if (!valueLayout.isAssigned())
- return;
- LayoutInfo tensorDescLayout = valueLayout;
- // LoadNdOp has the transpose effect. However, at the stage of this analysis
- // this effect is not expected and should be abstracted away. Emit a
- // warning.
- if (auto transpose = load.getTranspose()) {
- load.emitWarning("Transpose effect is not expected for LoadNdOp at "
- "LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.transpose(transpose.value());
+
+ LayoutInfo loadLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ } else {
+
+ LayoutInfo valueLayout = results[0]->getValue();
+ // Need the layout of the value to propagate to the tensor descriptor.
+ if (!valueLayout.isAssigned())
+ return;
+ loadLayout = valueLayout;
+ // LoadNdOp has the transpose effect. However, at the stage of this analysis
+ // this effect is not expected and should be abstracted away. Emit a
+ // warning.
+ if (auto transpose = load.getTranspose()) {
+ load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+ "LayoutInfoPropagation stage.");
+ loadLayout = valueLayout.transpose(transpose.value());
+ }
+ load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
// Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
/// For vector::TransposeOp, the layout of the result is transposed and
@@ -802,33 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // The layout is strictly determined by the payload type.
- auto payloadTy = dyn_cast<VectorType>(load.getValueType());
- if (!payloadTy) {
- load.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- auto uArch = getUArch(getChipStr(load).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto srcTdescTy =
- dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
- if (srcTdescTy.getChunkSizeAsInt() > 1)
+
+ LayoutInfo loadLayout;
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ loadLayout = LayoutInfo(anchorLayout);
+ maskLayout = loadLayout;
+ } else {
+
+ // The layout is strictly determined by the payload type.
+ VectorType payloadTy = load.getValueType();
+ if (!payloadTy) {
+ load.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+ auto uArch = getUArch(getChipStr(load).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
instData.push_back(chunkSize);
- }
- LayoutInfo layout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
- /*scattered*/ true);
+ else if (auto srcTdescTy =
+ dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
+ if (srcTdescTy.getChunkSizeAsInt() > 1)
+ instData.push_back(chunkSize);
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ loadLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+ else
+ loadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered*/ true);
- // Mask operand should have 1D default layout.
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+ // Mask operand should have 1D default layout.
+ maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+ load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+ }
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
- propagateIfChanged(operands[0], operands[0]->meet(layout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets())
@@ -856,45 +992,56 @@ void LayoutInfoPropagation::visitCreateDescOp(
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Currently, for 2D StoreScatterOp we expect that the height dimension of
- // the tensor descriptor is equal to the subgroup size. This is ensured by
- // the op verifier.
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
- if (!payloadTy) {
- storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
- return;
- }
- auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
-
- auto payloadShape = payloadTy.getShape();
- if (payloadShape.size() > 1)
- assert(
- payloadShape[0] == subgroupSize &&
- "Expected the first dimension of 2D tensor descriptor to be equal to "
- "subgroup size.");
-
- SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto dstTdescTy =
- dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
- if (dstTdescTy.getChunkSizeAsInt() > 1)
- instData.push_back(chunkSize);
- }
LayoutInfo payloadLayout;
-
- if (auto layout = storeScatter.getLayoutAttr()) {
- payloadLayout = LayoutInfo(layout);
+ LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ payloadLayout = LayoutInfo(anchorLayout);
+ maskLayout = payloadLayout;
} else {
- payloadLayout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
- /*scattered=*/true);
- }
+ // Currently, for 2D StoreScatterOp we expect that the height dimension of
+ // the tensor descriptor is equal to the subgroup size. This is ensured by
+ // the op verifier.
+ VectorType payloadTy = storeScatter.getValueType();
+ if (!payloadTy) {
+ storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
- LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+
+ if (layoutKind == LayoutKind::InstData) {
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+ chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+ storeScatter.getDestType())) {
+ if (dstTdescTy.getChunkSizeAsInt() > 1)
+ instData.push_back(chunkSize);
+ }
+ payloadLayout = LayoutInfo(
+ xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+ } else {
+ auto payloadShape = payloadTy.getShape();
+ if (payloadShape.size() > 1)
+ assert(payloadShape[0] == subgroupSize &&
+ "Expected the first dimension of 2D tensor descriptor to be "
+               "equal to subgroup size.");
+ payloadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered=*/true);
+ }
+
+ maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+
+ storeScatter.setLayoutAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
+ }
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
@@ -916,10 +1063,10 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op) : target(op) {
+ RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
(void)solver.initializeAndRun(op);
}
@@ -1159,7 +1306,18 @@ struct XeGPUPropagateLayoutPass final
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
- auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+ LayoutKind layoutKind;
+ if (this->layoutKind == "lane") {
+ layoutKind = LayoutKind::Lane;
+ } else if (this->layoutKind == "inst") {
+ layoutKind = LayoutKind::InstData;
+ } else {
+ getOperation()->emitError("Unsupported layout kind option: " +
+ this->layoutKind);
+ signalPassFailure();
+ return;
+ }
+ RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1173,8 +1331,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return {};
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
- if (this->layoutKind == "lane")
- layoutAttr = layoutAttr.dropInstData();
if (layout.isSliceLayout())
return cast<xegpu::SliceAttr>(layoutAttr);
return cast<xegpu::LayoutAttr>(layoutAttr);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bbd7733..ca81c3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -99,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
if (i < distributionStart)
continue;
-
// Check if the dimension can be distributed evenly.
if (dim % effectiveLaneLayout[i - distributionStart] != 0)
return failure();
@@ -174,6 +173,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+ VectorType distributedType) {
+ assert(originalType.getRank() == distributedType.getRank() &&
+ "sequential and distributed vector types must have the same rank");
+ SmallVector<int64_t> distributedDims;
+ for (int64_t i = 0; i < originalType.getRank(); ++i) {
+ if (distributedType.getDimSize(i) != originalType.getDimSize(i)) {
+ distributedDims.push_back(i);
+ }
+ }
+ return distributedDims;
+}
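[Editor's note] A usage sketch of the helper above with made-up shapes, assuming an `MLIRContext *ctx` is in scope and the file's usual `mlir` using-directives apply: distributing a `vector<8x32xf32>` across 16 lanes along the innermost dimension yields `vector<8x2xf32>`, so only dimension 1 changes size.

```cpp
// Hypothetical usage of getDistributedDims with assumed shapes.
static SmallVector<int64_t> exampleDistributedDims(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto sequentialTy = VectorType::get({8, 32}, f32); // per-warp shape
  auto perLaneTy = VectorType::get({8, 2}, f32);     // 32 elems / 16 lanes
  // Only dim 1 differs in size, so this returns {1}.
  return getDistributedDims(sequentialTy, perLaneTy);
}
```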
+
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
@@ -926,8 +940,7 @@ static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
getAsOpFoldResult(origOffsets));
- newCoods = llvm::to_vector(llvm::map_range(
- ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
return newCoods;
}
@@ -990,9 +1003,8 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Value> newOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
- SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
- std::fill(newConstOffsets.begin(), newConstOffsets.end(),
- ShapedType::kDynamic);
+ SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
+ ShapedType::kDynamic);
DenseI64ArrayAttr newConstOffsetsAttr =
rewriter.getDenseI64ArrayAttr(newConstOffsets);
ValueRange currentOffsets =
@@ -1067,9 +1079,8 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Value> newOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
- SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
- std::fill(newConstOffsets.begin(), newConstOffsets.end(),
- ShapedType::kDynamic);
+ SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(),
+ ShapedType::kDynamic);
DenseI64ArrayAttr newConstOffsetsAttr =
rewriter.getDenseI64ArrayAttr(newConstOffsets);
ValueRange currentOffsets =
@@ -1412,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
}
};
+/// This pattern distributes the `vector.broadcast` operation across lanes in a
+/// warp. The pattern supports three use cases:
+///
+/// 1) Broadcast a low-rank vector to a high-rank vector: The low-rank input
+///    vector must have a slice layout of the result. If the distributed source
+///    and target vector types are identical, this lowers to a no-op;
+///    otherwise, it remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: The source vector must have unit dimensions, and lane_data must
+///    be unit size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
+/// scalar to distributed result type.
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [0]> } : () -> (vector<32xf32>)
+/// %2 = vector.broadcast %0 {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+/// : vector<32xf32> to vector<8x32xf32>
+/// gpu.yield %1 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [0]> } : () -> (vector<32xf32>)
+/// gpu.yield %0 : vector<32xf32>
+/// }
+/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+/// ```
+///
+/// Example 2 (no-op):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [1]> } : () -> (vector<8xf32>)
+/// %1 = vector.shape_cast %0
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8xf32> to vector<8x1xf32>
+/// %2 = vector.broadcast %1
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8x1xf32> to vector<8x32xf32>
+/// gpu.yield %1 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+/// %0 = "some_def"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+/// dims = [1]> } : () -> (vector<8xf32>)
+/// %1 = vector.shape_cast %0
+/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+/// 1]>}: vector<8xf32> to vector<8x1xf32>
+/// gpu.yield %1 : vector<8x1xf32>
+/// }
+/// // The broadcast is implicit through layout transformation (no-op)
+/// "some_use"(%r#0)
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+ if (!yieldOperand)
+ return failure();
+ auto broadcastOp =
+ cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandIdx = yieldOperand->getOperandNumber();
+
+ VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ VectorType destType =
+ dyn_cast<VectorType>(broadcastOp.getResult().getType());
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+
+ FailureOr<VectorType> sourceDistType;
+ Type sourceElemOrDistType;
+ if (sourceType) {
+
+ // Case 1 and 2: source is a vector type.
+ int64_t rankDiff = destType.getRank() - sourceType.getRank();
+ if (rankDiff > 0) {
+ // Case 1: source is lower-rank than result.
+ bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
+ if (!isSliceOf)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Broadcast input layout must be a slice of result layout.");
+ }
+ // case 2: source and result have same rank
+ if (rankDiff == 0) {
+ SetVector<int64_t> broadcastUnitDims =
+ broadcastOp.computeBroadcastedUnitDims();
+ resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+ bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
+ if (!isEqualTo)
+ return rewriter.notifyMatchFailure(
+ warpOp, "For same-rank broadcast, source must be identical to "
+ "adjusted result layouts with unit dims.");
+ sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
+ }
+
+ sourceDistType =
+ getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+ if (failed(sourceDistType)) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the source vector type.");
+ }
+ sourceElemOrDistType = sourceDistType.value();
+
+ } else {
+ // Case 3: source is a scalar type.
+ if (sourceLayout) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Broadcast from scalar must not have a layout attribute.");
+ }
+ sourceElemOrDistType = broadcastOp.getSourceType();
+ }
+ FailureOr<VectorType> destDistType =
+ getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+ if (failed(destDistType)) {
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the dest vector type.");
+ }
+
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
+ newRetIndices);
+
+ Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+
+ Value newBroadcast = distributedSource;
+
+ if (sourceElemOrDistType != destDistType.value()) {
+ rewriter.setInsertionPointAfter(newWarpOp);
+ newBroadcast =
+ vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
+ destDistType.value(), distributedSource);
+ }
+
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
+ return success();
+ }
+};
+
/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
@@ -1472,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
+/// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases in which the distributed dimension is partially extracted,
+/// which the generic vector distribution patterns do not currently support.
+struct VectorExtractStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ auto extractOp =
+ cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+ unsigned operandIdx = operand->getOperandNumber();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Find the distributed dimensions.
+ auto extractResultType = cast<VectorType>(operand->get().getType());
+ auto distributedDims =
+ getDistributedDims(extractResultType, distributedType);
+ // Collect updated source type, sizes and offsets. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ VectorType updatedSourceType = extractOp.getSourceVectorType();
+ SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+ extractOp.getSizes(), [](Attribute attr) { return attr; });
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ extractOp.getOffsets(), [](Attribute attr) { return attr; });
+ // If the result is distributed, it must be distributed in exactly one
+ // dimension. In this case, we adjust the sourceDistType, distributedSizes
+ // and distributedOffsets accordingly.
+ if (distributedDims.size() > 0) {
+ if (distributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+            warpOp, "Source cannot be distributed in multiple dimensions.");
+ int64_t distributedDim = distributedDims[0];
+ int sourceDistrDimSize =
+ extractOp.getSourceVectorType().getShape()[distributedDim];
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source of extract_strided_slice op lacks distribution "
+ "layout");
+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize = sourceLaneLayout[distributedDim];
+ // Check if the source size in the distributed dimension is a multiple of
+ // subgroup size.
+ if (sourceDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Source size along distributed dimension is not a multiple of "
+ "subgroup size.");
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ // We expect lane data to be all ones in this case.
+ if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimension must be a multiple of subgroup
+ // size.
+ int64_t distrDimOffset =
+ cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+ if (distrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Offset along distributed dimension "
+ "is not a multiple of subgroup size.");
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, extractOp.getSourceVectorType())
+ .value();
+ // Update the distributed sizes to match the distributed type.
+ updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+ distributedType.getDimSize(distributedDim));
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[distributedDim] =
+ rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source of the extract op from
+ // the warp op and creating a new extract op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ // Create a new extract op outside the warp op.
+ Value newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), distributedType, source,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ ArrayAttr::get(rewriter.getContext(), updatedSizes),
+ extractOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp);
+ return success();
+ }
+};
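[Editor's note] To make the offset/size rescaling above concrete, here is a small worked example with assumed numbers (a `vector<64xf32>` source, `lane_layout = [16]`, unit lane_data, extracting 32 elements at offset 16). The constants below only restate the arithmetic the pattern performs; they are not part of the pass.

```cpp
// Hypothetical numbers illustrating the round-robin offset/size rescaling.
constexpr int64_t subgroupSize = 16; // lane_layout along the distributed dim
constexpr int64_t warpOffset = 16;   // offsets[] before distribution
constexpr int64_t warpSize = 32;     // sizes[] before distribution
static_assert(warpOffset % subgroupSize == 0, "offset must be lane aligned");
static_assert(warpSize % subgroupSize == 0, "size must be lane aligned");
constexpr int64_t laneOffset = warpOffset / subgroupSize; // 1
constexpr int64_t laneSize = warpSize / subgroupSize;     // 2
// Each lane extracts laneSize elements starting at laneOffset from its
// vector<4xf32> slice of the vector<64xf32> source (64 / 16 = 4).
```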
+
+/// Distribute a `vector.insert_strided_slice` op feeding into the yield op of
+/// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers
+/// advanced cases in which the distributed dimension is partially inserted,
+/// which the generic vector distribution patterns do not currently support.
+struct VectorInsertStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Find the distributed dimensions of the dest vector.
+ auto insertResultType = cast<VectorType>(operand->get().getType());
+ auto destDistributedDims =
+ getDistributedDims(insertResultType, distributedType);
+ // Collect updated offsets, source type and dest type. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ insertOp.getOffsets(), [](Attribute attr) { return attr; });
+ VectorType updatedSourceType = insertOp.getSourceVectorType();
+ VectorType updatedDestType = insertOp.getDestVectorType();
+ if (destDistributedDims.size() > 0) {
+ // Only single dimension distribution is supported.
+ if (destDistributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+              "Expecting dest to be distributed in a single dimension.");
+ int64_t destDistributedDim = destDistributedDims[0];
+
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp,
+ "distributed dimension must be in the last k (i.e. source "
+ "rank) dims of dest vector");
+ int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
+ // Obtain the source and dest layouts.
+ auto destLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0));
+ if (!destLayout || !sourceLayout ||
+ destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source or dest of insert_strided_slice op lacks "
+ "distribution layout");
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize =
+ destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim];
+ // We require that source and dest lane data are all ones to ensure
+ // uniform round robin distribution.
+ auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) ||
+ !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source and dest layouts");
+ // Source distributed dim size must be multiples of subgroup size.
+ if (srcDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Distributed dimension size in source is not a multiple of "
+ "subgroup size.");
+ // Offsets in the distributed dimension must be multiples of subgroup
+ // size.
+ int64_t destDistrDimOffset =
+ cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt();
+ if (destDistrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Offset along distributed dimension in dest is not a multiple of "
+ "subgroup size.");
+ // Update the source and dest types based on their layouts.
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, insertOp.getSourceVectorType())
+ .value();
+ updatedDestType = getDistVecTypeBasedOnLaneLayout(
+ destLayout, insertOp.getDestVectorType())
+ .value();
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[destDistributedDim] =
+ rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source and dest of the insert op
+ // from the warp op and creating a new insert op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {updatedSourceType, updatedDestType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ Value valueToStore = newWarpOp.getResult(newRetIndices[0]);
+ Value dest = newWarpOp.getResult(newRetIndices[1]);
+ // Create a new insert op outside the warp op.
+ Value newInsertOp = vector::InsertStridedSliceOp::create(
+ rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ insertOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+ newInsertOp);
+ return success();
+ }
+};
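[Editor's note] A small worked example of the dest-to-source dimension mapping used above, with assumed ranks (a 2-D source inserted into a 3-D dest that is distributed along its last dimension). Again, these constants only restate the arithmetic in the pattern.

```cpp
// Hypothetical ranks illustrating the sourceDistributedDim computation.
constexpr int64_t destRank = 3, srcRank = 2;
constexpr int64_t destDistributedDim = 2; // last dest dim is distributed
constexpr int64_t sourceDistributedDim =
    destDistributedDim - (destRank - srcRank); // 2 - (3 - 2) = 1
static_assert(sourceDistributedDim >= 0,
              "distributed dim must fall in the trailing srcRank dims");
```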
+
/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an
/// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op
/// outside of the warp op.
@@ -1629,9 +2020,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
- patterns.add<VectorShapeCastDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/highPatternBenefit);
+  // For the following patterns we need to override the regular vector
+  // distribution patterns, therefore assign a higher benefit.
+ patterns
+ .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
+ VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/highPatternBenefit);
}
void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c3bf960..af63f09 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
// return dummy Value to satisfy function's signature
return nullptr;
};
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return xegpu::LoadNdOp::create(
rewriter, loc, newValueTy, convertedTdescs[0], offsets,
op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
};
newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
*targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
if (!targetShape)
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropInstData();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
convertedTdescs[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
// return dummy Value to satisfy function's signature
return nullptr;
};
@@ -678,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}
- auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+ auto layout = op.getLayoutAttr();
if (layout)
layout = layout.dropInstData();
@@ -778,7 +787,7 @@ struct UnrollStoreScatterOpWithOffsets
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
- auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+ auto layout = op.getLayoutAttr();
if (layout)
layout = layout.dropInstData();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0a..be82cda 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
if (origOffsets.empty())
return failure();
+  // LoadMatrixOp/StoreMatrixOp carry the layout as an op attribute; other ops
+  // (e.g. xegpu::CreateNdDescOp) expose it via getDescLayoutAttr().
+ xegpu::DistributeLayoutAttr layout;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> ||
+ std::is_same_v<OpType, xegpu::StoreMatrixOp>) {
+ layout = op.getLayoutAttr();
+ } else {
+ layout = op.getDescLayoutAttr();
+ }
+
// not applicable to ops without workgroup layout attributes
- xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -190,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
xegpu::TensorDescType tdescTy = op.getType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
Type elemTy = tdescTy.getElementType();
- xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
auto newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
@@ -309,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
SmallVector<Value> newOps;
for (auto [tdesc, offsets] :
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -318,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
auto newOp = xegpu::LoadNdOp::create(
rewriter, op.getLoc(), newResTy, tdesc, offsets,
/*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ op.getL2HintAttr(), op.getL3HintAttr(), layout);
newOps.push_back(newOp);
}
rewriter.replaceOpWithMultiple(op, {newOps});
@@ -339,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
for (auto [v, tdesc, offsets] :
llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
@@ -363,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
if (failed(genOffsetsList(rewriter, op, offsetsList)))
return failure();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (layout)
+ layout = layout.dropSgLayoutAndData();
for (auto [tdesc, offsets] :
llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr());
+ op.getL3HintAttr(), layout);
}
rewriter.eraseOp(op);
@@ -489,10 +506,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -738,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Location loc = op.getLoc();
auto eltType = vecType.getElementType();
- auto setLayoutIfNeeded = [&](Value val) {
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
- layout.dropSgLayoutAndData());
- }
+ auto setLayout = [&](Value val) {
+ xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+ layout.dropSgLayoutAndData());
};
if (vecAttr.isSplat()) {
@@ -751,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
- setLayoutIfNeeded(cstOp->getResult(0));
+ setLayout(cstOp->getResult(0));
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
// subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
- setLayoutIfNeeded(newConstOp->getResult(0));
+ setLayout(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
@@ -860,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
- setLayoutIfNeeded(baseConstVec);
- setLayoutIfNeeded(bcastOffset);
- setLayoutIfNeeded(finalConst);
+ setLayout(baseConstVec);
+ setLayout(bcastOffset);
+ setLayout(finalConst);
newConstOps.push_back(finalConst);
}
rewriter.replaceOpWithMultiple(op, {newConstOps});
@@ -889,8 +901,8 @@ struct WgToSgLoadGatherOpWithOffset
return failure();
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
- xegpu::getDistributeLayoutAttr(op.getResult()));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -913,10 +925,12 @@ struct WgToSgLoadGatherOpWithOffset
VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
for (auto [offsets, mask] :
llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLayout = layout.dropSgLayoutAndData();
auto newLoadOp = xegpu::LoadGatherOp::create(
rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
- layout.dropSgLayoutAndData());
+ newLayout);
+ xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
newLoadOps.push_back(newLoadOp);
}
rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -941,8 +955,8 @@ struct WgToSgStoreScatterOpWithOffset
if (!valueType)
return failure();
- xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
- xegpu::getDistributeLayoutAttr(op.getOperand(0)));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getOperand(0));
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -967,14 +981,11 @@ struct WgToSgStoreScatterOpWithOffset
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
layout.dropSgLayoutAndData());
// Update the layout attribute to drop sg_layout and sg_data.
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- for (OpOperand &operand : store->getOpOperands()) {
- // Skip for operand one (memref)
- if (operand.getOperandNumber() == 1)
- continue;
- xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
- }
+ for (OpOperand &operand : store->getOpOperands()) {
+ // Skip for operand one (memref)
+ if (operand.getOperandNumber() == 1)
+ continue;
+ xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
}
}
rewriter.eraseOp(op);
@@ -1067,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
auto finalSteps =
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty()) {
- xegpu::setDistributeLayoutAttr(steps->getResult(0),
- layout.dropSgLayoutAndData());
- xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
- layout.dropSgLayoutAndData());
- xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
- layout.dropSgLayoutAndData());
- }
+ xegpu::setDistributeLayoutAttr(steps->getResult(0),
+ layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+ layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+ layout.dropSgLayoutAndData());
newOps.push_back(finalSteps);
}
@@ -1143,10 +1151,8 @@ struct WgToSgVectorShapeCastOp
for (auto src : adaptor.getSource()) {
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
newResultType, src);
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+ layout.dropSgLayoutAndData());
newShapeCastOps.push_back(newShapeCast.getResult());
}
@@ -1207,10 +1213,8 @@ struct WgToSgMultiDimReductionOp
auto newOp = vector::MultiDimReductionOp::create(
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
adaptor.getAcc()[0], op.getReductionDims());
- if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
- !layout.getEffectiveInstDataAsInt().empty())
- xegpu::setDistributeLayoutAttr(newOp->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+ layout.dropSgLayoutAndData());
newReductions.push_back(newOp.getResult());
}
@@ -1283,6 +1287,78 @@ struct WgToSgVectorTransposeOp
}
};
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ MaskOpType op,
+ typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ SmallVector<Value> wgMaskDimSizes;
+ if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+ for (int64_t maskSize : op.getMaskDimSizes()) {
+ wgMaskDimSizes.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, maskSize));
+ }
+ } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+ wgMaskDimSizes = llvm::to_vector(op.getOperands());
+ }
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+ // In each dimension, each subgroup computes its local mask size as:
+ // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
+ SmallVector<Value> newCreateMaskOps;
+ for (auto offsetSet : *sgOffsets) {
+ SmallVector<Value> maskOperands;
+
+ for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
+ Value dimSizeVal =
+ arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ Value offset = offsetSet[i];
+ Value adjustedMaskSize =
+ arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value nonNegative =
+ arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+ Value sgMaskSize =
+ arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
+ maskOperands.push_back(sgMaskSize);
+ }
+
+ auto newCreateMaskOp =
+ vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+ xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+ return success();
+ }
+};
+
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
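[Editor's note] The per-subgroup mask size computed in the pattern above is a clamp of the workgroup-level mask size against the subgroup's offset and tile size. A standalone sketch with assumed numbers (wgShape = [128], sg_data = [32], workgroup mask size 100); the helper name is hypothetical and only restates the min/max formula from the comment.

```cpp
// Hypothetical restatement of min(max(wgMaskSize - offset, 0), sgDimSize).
static int64_t exampleSgMaskSize(int64_t wgMaskSize, int64_t sgOffset,
                                 int64_t sgDimSize) {
  int64_t remaining = wgMaskSize - sgOffset;
  if (remaining < 0)
    remaining = 0;                       // max(wgMaskSize - offset, 0)
  return remaining < sgDimSize ? remaining : sgDimSize; // min(.., sgDimSize)
}
// exampleSgMaskSize(100, 64, 32)  == 32  (fully active subgroup)
// exampleSgMaskSize(100, 96, 32)  == 4   (partially active subgroup)
// exampleSgMaskSize(100, 112, 32) == 0   (fully masked-off subgroup)
```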
} // namespace
namespace mlir {
@@ -1297,7 +1373,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
- WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+ WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
patterns.getContext());
}
} // namespace xegpu
@@ -1427,7 +1504,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
vector::TransposeOp, vector::BroadcastOp,
- vector::MultiDimReductionOp>(
+ vector::MultiDimReductionOp,
+ vector::ConstantMaskOp, vector::CreateMaskOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index de9e09d..9f126fe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -140,7 +139,6 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
// for StoreMatrixOp, the layout is attached to the property of the op
if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
return storeOp.getLayoutAttr();
-
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -308,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
int64_t rankDiff = srcShapeRank - targetShapeRank;
std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
1);
- std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+ llvm::copy(shape, adjustedTargetShape.begin() + rankDiff);
SmallVector<Value> result;
for (SmallVector<int64_t> offsets :
@@ -528,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
- results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval));
}
return results;
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000..f3e38eb
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,174 @@
+//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetic -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file exposes the APFloat infrastructure to MLIR programs as a runtime
+// library. APFloat is a software implementation of floating-point arithmetic.
+//
+// On the MLIR side, floating-point values must be bitcasted to 64-bit integers
+// before calling a runtime function. If a floating-point type has less than
+// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an
+// integer.
+//
+// Runtime functions receive the floating-point operands of the arithmetic
+// operation in the form of 64-bit integers, along with the APFloat semantics
+// in the form of a 32-bit integer, which will be interpreted as an
+// APFloatBase::Semantics enum value.
+//
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APSInt.h"
+
+#ifdef _WIN32
+#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT
+#ifdef mlir_apfloat_wrappers_EXPORTS
+// We are building this library
+#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllexport)
+#else
+// We are using this library
+#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllimport)
+#endif // mlir_apfloat_wrappers_EXPORTS
+#endif // MLIR_APFLOAT_WRAPPERS_EXPORT
+#else
+// Non-windows: use visibility attributes.
+#define MLIR_APFLOAT_WRAPPERS_EXPORT __attribute__((visibility("default")))
+#endif // _WIN32
+
+/// Binary operations without rounding mode.
+#define APFLOAT_BINARY_OP(OP) \
+ MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ lhs.OP(rhs); \
+ return lhs.bitcastToAPInt().getZExtValue(); \
+ }
+
+/// Binary operations with rounding mode.
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
+ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ lhs.OP(rhs, ROUNDING_MODE); \
+ return lhs.bitcastToAPInt().getZExtValue(); \
+ }
+
+extern "C" {
+
+#define BIN_OPS_WITH_ROUNDING(X) \
+ X(add, llvm::RoundingMode::NearestTiesToEven) \
+ X(subtract, llvm::RoundingMode::NearestTiesToEven) \
+ X(multiply, llvm::RoundingMode::NearestTiesToEven) \
+ X(divide, llvm::RoundingMode::NearestTiesToEven)
+
+BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
+#undef BIN_OPS_WITH_ROUNDING
+#undef APFLOAT_BINARY_OP_ROUNDING_MODE
+
+APFLOAT_BINARY_OP(remainder)
+
+#undef APFLOAT_BINARY_OP
+
+MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+ double d = x.convertToDouble();
+ fprintf(stdout, "%lg", d);
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
+_mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) {
+ const llvm::fltSemantics &inSem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(inSemantics));
+ const llvm::fltSemantics &outSem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(outSemantics));
+ unsigned bitWidthIn = llvm::APFloatBase::semanticsSizeInBits(inSem);
+ llvm::APFloat val(inSem, llvm::APInt(bitWidthIn, a));
+ // TODO: Custom rounding modes are not supported yet.
+ bool losesInfo;
+ val.convert(outSem, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+ llvm::APInt result = val.bitcastToAPInt();
+ return result.getZExtValue();
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
+ int32_t semantics, int32_t resultWidth, bool isUnsigned, uint64_t a) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned inputWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat val(sem, llvm::APInt(inputWidth, a));
+ llvm::APSInt result(resultWidth, isUnsigned);
+ bool isExact;
+ // TODO: Custom rounding modes are not supported yet.
+ val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact);
+ // This function always returns uint64_t, regardless of the desired result
+ // width. It does not matter whether we zero-extend or sign-extend the APSInt
+ // to 64 bits because the generated IR in arith-to-apfloat will truncate the
+ // result to the desired result width.
+ return result.getZExtValue();
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
+ int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) {
+ llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned);
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ llvm::APFloat result(sem);
+ // TODO: Custom rounding modes are not supported yet.
+ result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned,
+ llvm::RoundingMode::NearestTiesToEven);
+ return result.bitcastToAPInt().getZExtValue();
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
+ uint64_t a,
+ uint64_t b) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+ llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
+ return static_cast<int8_t>(x.compare(y));
+}
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+ x.changeSign();
+ return x.bitcastToAPInt().getZExtValue();
+}
+
+/// Min/max operations.
+#define APFLOAT_MIN_MAX_OP(OP) \
+ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ llvm::APFloat result = llvm::OP(lhs, rhs); \
+ return result.bitcastToAPInt().getZExtValue(); \
+ }
+
+APFLOAT_MIN_MAX_OP(minimum)
+APFLOAT_MIN_MAX_OP(maximum)
+APFLOAT_MIN_MAX_OP(minnum)
+APFLOAT_MIN_MAX_OP(maxnum)
+
+#undef APFLOAT_MIN_MAX_OP
+}
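[Editor's note] As a host-side illustration of the calling convention described in the file header (bitcast to an integer, zero-extend to 64 bits, pass the semantics enum), here is a small sketch that adds two f32 values through the wrapper. The semantics value comes from `llvm::APFloatBase::Semantics`; linking against the `mlir_apfloat_wrappers` library is assumed, and this caller is not part of the runtime library itself.

```cpp
// Hypothetical standalone caller of the wrappers above.
#include <cstdint>
#include <cstdio>
#include <cstring>

#include "llvm/ADT/APFloat.h"

extern "C" uint64_t _mlir_apfloat_add(int32_t semantics, uint64_t a,
                                      uint64_t b);

int main() {
  float x = 1.5f, y = 2.25f;
  uint32_t xBits, yBits;
  std::memcpy(&xBits, &x, sizeof(float)); // bitcast f32 -> i32
  std::memcpy(&yBits, &y, sizeof(float));
  auto sem = static_cast<int32_t>(llvm::APFloatBase::S_IEEEsingle);
  // Narrow operands are zero-extended to 64 bits, as the header requires.
  uint64_t sum = _mlir_apfloat_add(sem, xBits, yBits);
  auto sumBits = static_cast<uint32_t>(sum); // truncate back to 32 bits
  float result;
  std::memcpy(&result, &sumBits, sizeof(float));
  std::printf("%g\n", result); // prints 3.75
  return 0;
}
```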
diff --git a/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp b/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp
index 9868ffa..9b1c39e 100644
--- a/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp
@@ -49,7 +49,7 @@ extern "C" {
/// The recommended strategy is to call `setArmVectorLength` only from functions
/// that do not access SVE registers, either by themselves or by inlining other
/// functions.
-static void setArmVectorLength(std::string_view helper_name, int option,
+static void setArmVectorLength(std::string_view helperName, int option,
uint32_t bits) {
#if defined(__linux__) && defined(__aarch64__)
if (bits < 128 || bits > 2048 || !llvm::isPowerOf2_32(bits)) {
@@ -63,7 +63,7 @@ static void setArmVectorLength(std::string_view helper_name, int option,
abort();
}
#else
- std::cerr << "[error] " << helper_name << " is unsupported" << std::endl;
+ std::cerr << "[error] " << helperName << " is unsupported" << std::endl;
abort();
#endif
}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dac..a615352 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
# is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES
+ APFloatWrappers.cpp
ArmRunnerUtils.cpp
ArmSMEStubs.cpp
AsyncRuntime.cpp
@@ -167,6 +168,26 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ # TODO: This support library is only used on Linux builds until we figure
+ # out how to hide LLVM symbols in a way that works for all platforms.
+ add_mlir_library(mlir_apfloat_wrappers
+ SHARED
+ APFloatWrappers.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+ set_target_properties(
+ mlir_apfloat_wrappers
+ PROPERTIES CXX_STANDARD 17
+ CXX_VISIBILITY_PRESET hidden
+ VISIBILITY_INLINES_HIDDEN ON
+ )
+ target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS)
+ # Hide LLVM symbols to avoid ODR violations.
+ target_link_options(mlir_apfloat_wrappers PRIVATE "-Wl,--exclude-libs,ALL")
+ endif()
+
add_subdirectory(SparseTensor)
add_mlir_library(mlir_c_runner_utils
@@ -184,6 +205,11 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_c_runner_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_c_runner_utils PRIVATE mlir_c_runner_utils_EXPORTS)
+ # Conditionally link apfloat wrappers only on Linux.
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ target_link_libraries(mlir_c_runner_utils PUBLIC mlir_apfloat_wrappers)
+ endif()
+
add_mlir_library(mlir_runner_utils
SHARED
RunnerUtils.cpp
@@ -195,6 +221,11 @@ if(LLVM_ENABLE_PIC)
)
target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS)
+ # Conditionally link apfloat wrappers only on Linux.
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+ target_link_libraries(mlir_runner_utils PUBLIC mlir_apfloat_wrappers)
+ endif()
+
add_mlir_library(mlir_async_runtime
SHARED
AsyncRuntime.cpp
@@ -323,7 +354,6 @@ if(LLVM_ENABLE_PIC)
endif()
string(STRIP AGENTS_STRING ${AGENTS_STRING})
string(REPLACE "\n" ";" AGENTS_LIST ${AGENTS_STRING})
- list(FILTER AGENTS_LIST EXCLUDE REGEX "gfx000")
if (AGENTS_LIST STREQUAL "")
message(SEND_ERROR "No non-CPU ROCm agents found on the system, and ROCM_TEST_CHIPSET is not defined")
else()
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 6cc2b7fd..f203363 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -57,7 +57,7 @@
thread_local static int32_t defaultDevice = 0;
/// Helper method that checks environment value for debugging.
-bool isDebugEnabled() {
+static bool isDebugEnabled() {
const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
return isEnabled;
@@ -71,7 +71,7 @@ bool isDebugEnabled() {
} while (0)
// Returns default CUdevice
-CUdevice getDefaultCuDevice() {
+static CUdevice getDefaultCuDevice() {
CUdevice device;
CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
return device;
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 2255633..287c52a 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -146,12 +146,10 @@ static void packFunctionArguments(Module *module) {
llvm::IRBuilder<> builder(ctx);
DenseSet<llvm::Function *> interfaceFunctions;
for (auto &func : module->getFunctionList()) {
- if (func.isDeclaration()) {
+ if (func.isDeclaration() || func.hasLocalLinkage())
continue;
- }
- if (interfaceFunctions.count(&func)) {
+ if (interfaceFunctions.count(&func))
continue;
- }
// Given a function `foo(<...>)`, define the interface function
// `mlir_foo(i8**)`.
diff --git a/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp
index ddea230..ff0dd54 100644
--- a/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp
@@ -156,7 +156,7 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
size_t /*smem*/, void *vkRuntimeManager, void **params,
void ** /*extra*/, size_t paramsCount) {
- auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+ auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
// GpuToLLVMConversionPass with the kernelBarePtrCallConv and
// kernelIntersperseSizeCallConv options will set up the params array like:
@@ -180,7 +180,7 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
static_cast<uint32_t>(gridY),
static_cast<uint32_t>(gridZ)});
- auto function = static_cast<VulkanFunction *>(vkKernel);
+ auto *function = static_cast<VulkanFunction *>(vkKernel);
// Expected size should be in bytes.
manager->setShaderModule(
function->module->blobData(),
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9b23dd6..fd846e4 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2032,7 +2032,7 @@ private:
};
template <typename Range>
-void printDimensionList(raw_ostream &stream, Range &&shape) {
+static void printDimensionList(raw_ostream &stream, Range &&shape) {
llvm::interleave(
shape, stream,
[&stream](const auto &dimSize) {
diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp
index 031eae2..4cce16b 100644
--- a/mlir/lib/IR/Remarks.cpp
+++ b/mlir/lib/IR/Remarks.cpp
@@ -31,6 +31,11 @@ Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) {
os << t;
}
+Remark::Arg::Arg(llvm::StringRef k, Attribute a) : key(k), attr(a) {
+ llvm::raw_string_ostream os(val);
+ os << a;
+}
+
void Remark::insert(llvm::StringRef s) { args.emplace_back(s); }
void Remark::insert(Arg a) { args.push_back(std::move(a)); }
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index e438631..199744d2 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -118,8 +118,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
/// dims are equal. The element type does not matter.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
- auto shapedTypes = llvm::map_to_vector<8>(
- types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
+ auto shapedTypes = llvm::map_to_vector<8>(types, llvm::DynCastTo<ShapedType>);
// Return failure if some, but not all are not shaped. Return early if none
// are shaped also.
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9f4f672..c31e0ae7 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
return status;
}
+FailureOr<SmallVector<OpFoldResult>>
+mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) {
+ auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+ if (!reifiableOp)
+ return failure();
+ return reifiableOp.reifyShapeOfResult(b, resultIndex);
+}
+
+FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
+ int resultIndex, int dim) {
+ auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+ if (!reifiableOp)
+ return failure();
+ return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
+}
+
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
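reifyShapeOfResult and reifyDimOfResult are thin per-result and per-dimension entry points over ReifyRankedShapedTypeOpInterface, so a caller that only needs one dimension no longer has to reify every result shape. A hedged usage sketch (the helper name and the Arith utility used to materialize the OpFoldResult are assumptions of this example, not part of the patch):

    // Sketch: materialize the size of dimension `dim` of result `resultIndex`
    // for an op that implements ReifyRankedShapedTypeOpInterface.
    #include "mlir/Dialect/Arith/Utils/Utils.h"
    #include "mlir/Interfaces/InferTypeOpInterface.h"

    static mlir::FailureOr<mlir::Value>
    materializeResultDim(mlir::OpBuilder &b, mlir::Operation *op,
                         int resultIndex, int dim) {
      mlir::FailureOr<mlir::OpFoldResult> size =
          mlir::reifyDimOfResult(b, op, resultIndex, dim);
      if (failed(size))
        return mlir::failure();
      // An OpFoldResult is either an Attribute (static size) or a Value
      // (dynamic size); fold both cases into an index-typed Value.
      return mlir::getValueOrCreateConstantIndexOp(b, op->getLoc(), *size);
    }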
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index a5bfde1..cfe808b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -129,7 +129,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
assert(var.map.getNumDims() == 0 && "expected only symbols");
SmallVector<AffineExpr> symReplacements;
for (auto valueDim : var.mapOperands) {
- auto it = llvm::find(this->mapOperands, valueDim);
+ auto *it = llvm::find(this->mapOperands, valueDim);
if (it != this->mapOperands.end()) {
// There is already a symbol for this operand.
symReplacements.push_back(b.getAffineSymbolExpr(
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 521c7c6..75f8826 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -559,9 +559,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
return op->emitOpError() << "trying to schedule a pass on an operation not "
"marked as 'IsolatedFromAbove'";
}
- if (!pass->canScheduleOn(*op->getName().getRegisteredInfo())) {
- return op->emitOpError()
- << "trying to schedule a pass on an unsupported operation";
+ if (!pass->canScheduleOn(op)) {
+ return op->emitOpError() << "trying to schedule pass '" << pass->getName()
+ << "' on an unsupported operation";
}
// Initialize the pass state with a callback for the pass to dynamically
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index e392a88..7bfe03d 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -27,7 +27,7 @@ struct Parser::TokenInfo {
}
// Known identifiers.
- static const char *const ID_Extract;
+ static const char *const idExtract;
llvm::StringRef text;
TokenKind kind = TokenKind::Eof;
@@ -35,7 +35,7 @@ struct Parser::TokenInfo {
VariantValue value;
};
-const char *const Parser::TokenInfo::ID_Extract = "extract";
+const char *const Parser::TokenInfo::idExtract = "extract";
class Parser::CodeTokenizer {
public:
@@ -452,13 +452,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
}
if (chainCallToken.kind != TokenKind::Ident ||
- chainCallToken.text != TokenInfo::ID_Extract) {
+ chainCallToken.text != TokenInfo::idExtract) {
error->addError(chainCallToken.range,
ErrorType::ParserMalformedChainedExpr);
return false;
}
- if (chainCallToken.text == TokenInfo::ID_Extract &&
+ if (chainCallToken.text == TokenInfo::idExtract &&
!parseChainedExpression(functionName))
return false;
}
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204..1e00ed6 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -175,9 +175,12 @@ public:
using Base::Base;
// Collect the reduce patterns defined by each dialect.
- void populateReductionPatterns(RewritePatternSet &pattern) const {
- for (const DialectReductionPatternInterface &interface : *this)
+ void populateReductionPatterns(RewritePatternSet &pattern,
+ Tester &tester) const {
+ for (const DialectReductionPatternInterface &interface : *this) {
interface.populateReductionPatterns(pattern);
+ interface.populateReductionPatternsWithTester(pattern, tester);
+ }
}
};
@@ -201,15 +204,21 @@ public:
private:
LogicalResult reduceOp(ModuleOp module, Region &region);
+ Tester tester;
FrozenRewritePatternSet reducerPatterns;
};
} // namespace
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+ tester.setTestScript(testerName);
+ tester.setTestScriptArgs(testerArgs);
+
RewritePatternSet patterns(context);
+
ReductionPatternInterfaceCollection reducePatternCollection(context);
- reducePatternCollection.populateReductionPatterns(patterns);
+ reducePatternCollection.populateReductionPatterns(patterns, tester);
+
reducerPatterns = std::move(patterns);
return success();
}
@@ -244,11 +253,10 @@ void ReductionTreePass::runOnOperation() {
}
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
- Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, tester);
default:
return module.emitError() << "unsupported traversal mode detected";
}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index c857c38..4312100 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
+ x86vector::registerTransformDialectExtension(registry);
xegpu::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 42843ea..159aa54 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1099,12 +1099,12 @@ public:
MutableArrayRef<PDLValue> getResults() { return results; }
/// Return the type ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+ MutableArrayRef<std::vector<Type>> getAllocatedTypeRanges() {
return allocatedTypeRanges;
}
/// Return the value ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+ MutableArrayRef<std::vector<Value>> getAllocatedValueRanges() {
return allocatedValueRanges;
}
};
@@ -1112,19 +1112,20 @@ public:
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
- ByteCodeExecutor(
- const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
- MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
- MutableArrayRef<TypeRange> typeRangeMemory,
- std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
- MutableArrayRef<ValueRange> valueRangeMemory,
- std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
- MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
- ArrayRef<ByteCodeField> code,
- ArrayRef<PatternBenefit> currentPatternBenefits,
- ArrayRef<PDLByteCodePattern> patterns,
- ArrayRef<PDLConstraintFunction> constraintFunctions,
- ArrayRef<PDLRewriteFunction> rewriteFunctions)
+ ByteCodeExecutor(const ByteCodeField *curCodeIt,
+ MutableArrayRef<const void *> memory,
+ MutableArrayRef<std::vector<Operation *>> opRangeMemory,
+ MutableArrayRef<TypeRange> typeRangeMemory,
+ std::vector<std::vector<Type>> &allocatedTypeRangeMemory,
+ MutableArrayRef<ValueRange> valueRangeMemory,
+ std::vector<std::vector<Value>> &allocatedValueRangeMemory,
+ MutableArrayRef<unsigned> loopIndex,
+ ArrayRef<const void *> uniquedMemory,
+ ArrayRef<ByteCodeField> code,
+ ArrayRef<PatternBenefit> currentPatternBenefits,
+ ArrayRef<PDLByteCodePattern> patterns,
+ ArrayRef<PDLConstraintFunction> constraintFunctions,
+ ArrayRef<PDLRewriteFunction> rewriteFunctions)
: curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
typeRangeMemory(typeRangeMemory),
allocatedTypeRangeMemory(allocatedTypeRangeMemory),
@@ -1367,13 +1368,9 @@ private:
if (range.empty()) {
rangeMemory[rangeIndex] = {};
} else {
- // Allocate a buffer for this type range.
- llvm::OwningArrayRef<T> storage(llvm::size(range));
- llvm::copy(range, storage.begin());
-
// Assign this to the range slot and use the range as the value for the
// memory index.
- allocatedRangeMemory.emplace_back(std::move(storage));
+ allocatedRangeMemory.emplace_back(range.begin(), range.end());
rangeMemory[rangeIndex] = allocatedRangeMemory.back();
}
memory[memIndex] = &rangeMemory[rangeIndex];
@@ -1397,11 +1394,11 @@ private:
/// The current execution memory.
MutableArrayRef<const void *> memory;
- MutableArrayRef<OwningOpRange> opRangeMemory;
+ MutableArrayRef<std::vector<Operation *>> opRangeMemory;
MutableArrayRef<TypeRange> typeRangeMemory;
- std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
+ std::vector<std::vector<Type>> &allocatedTypeRangeMemory;
MutableArrayRef<ValueRange> valueRangeMemory;
- std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
+ std::vector<std::vector<Value>> &allocatedValueRangeMemory;
/// The current loop indices.
MutableArrayRef<unsigned> loopIndex;
@@ -1907,10 +1904,10 @@ void ByteCodeExecutor::executeGetUsers() {
LDBG() << "Executing GetUsers:";
unsigned memIndex = read();
unsigned rangeIndex = read();
- OwningOpRange &range = opRangeMemory[rangeIndex];
+ std::vector<Operation *> &range = opRangeMemory[rangeIndex];
memory[memIndex] = &range;
- range = OwningOpRange();
+ range.clear();
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
// Read the value.
Value value = read<Value>();
@@ -1918,9 +1915,7 @@ void ByteCodeExecutor::executeGetUsers() {
return;
LDBG() << " * Value: " << value;
- // Extract the users of a single value.
- range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
- llvm::copy(value.getUsers(), range.begin());
+ range.assign(value.user_begin(), value.user_end());
} else {
// Read a range of values.
ValueRange *values = read<ValueRange *>();
@@ -1929,12 +1924,8 @@ void ByteCodeExecutor::executeGetUsers() {
LDBG() << " * Values (" << values->size()
<< "): " << llvm::interleaved(*values);
- // Extract all the users of a range of values.
- SmallVector<Operation *> users;
for (Value value : *values)
- users.append(value.user_begin(), value.user_end());
- range = OwningOpRange(users.size());
- llvm::copy(users, range.begin());
+ range.insert(range.end(), value.user_begin(), value.user_end());
}
LDBG() << " * Result: " << range.size() << " operations";
@@ -2174,7 +2165,8 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
executeEraseOp(rewriter);
break;
case ExtractOp:
- executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
+ executeExtract<Operation *, std::vector<Operation *>,
+ PDLValue::Kind::Operation>();
break;
case ExtractType:
executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index 4aceac7..566c1cb 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -30,7 +30,6 @@ class PDLByteCode;
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
-using OwningOpRange = llvm::OwningArrayRef<Operation *>;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
@@ -94,21 +93,21 @@ private:
/// the bytecode to store ranges of operations. These are always stored by
/// owning references, because at no point in the execution of the byte code
/// we get an indexed range (view) of operations.
- std::vector<OwningOpRange> opRangeMemory;
+ std::vector<std::vector<Operation *>> opRangeMemory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of types.
std::vector<TypeRange> typeRangeMemory;
/// A set of type ranges that have been allocated by the byte code interpreter
/// to provide a guaranteed lifetime.
- std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
+ std::vector<std::vector<Type>> allocatedTypeRangeMemory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of values.
std::vector<ValueRange> valueRangeMemory;
/// A set of value ranges that have been allocated by the byte code
/// interpreter to provide a guaranteed lifetime.
- std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
+ std::vector<std::vector<Value>> allocatedValueRangeMemory;
/// The current index of ranges being iterated over for each level of nesting.
/// These are always maintained at 0 for the loops that are not active, so we
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index b0ad3ee..77a6cec 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
bool TypeInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("TypeInterface");
}
+
+//===----------------------------------------------------------------------===//
+// DialectInterface
+//===----------------------------------------------------------------------===//
+
+bool DialectInterface::classof(const Interface *interface) {
+ return interface->getDef().isSubClassOf("DialectInterface");
+}
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 1a1a58a..ce09f5c 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -771,15 +772,27 @@ int Pattern::getBenefit() const {
return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue();
}
-std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
+std::vector<Pattern::IdentifierLine>
+Pattern::getLocation(bool forSourceOutput) const {
std::vector<std::pair<StringRef, unsigned>> result;
result.reserve(def.getLoc().size());
for (auto loc : def.getLoc()) {
unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
assert(buf && "invalid source location");
- result.emplace_back(
- llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
- llvm::SrcMgr.getLineAndColumn(loc, buf).first);
+
+ StringRef bufferName =
+ llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier();
+ // If we're emitting a generated file, we'd like to have some indication of
+ // where our patterns came from. However, LLVM's build rules use absolute
+ // paths as arguments to TableGen, and naively echoing such paths makes the
+ // contents of the generated source file depend on the build location,
+  // making MLIR builds substantially less reproducible. As a compromise, we
+ // trim absolute paths back to only the filename component.
+ if (forSourceOutput && llvm::sys::path::is_absolute(bufferName))
+ bufferName = llvm::sys::path::filename(bufferName);
+
+ result.emplace_back(bufferName,
+ llvm::SrcMgr.getLineAndColumn(loc, buf).first);
}
return result;
}
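The trimming only changes how pattern locations are rendered into generated C++ sources; a quick illustration of the effect, using the same llvm::sys::path helpers (the example path is made up):

    // Illustration of the location trimming applied for source output.
    #include "llvm/ADT/StringRef.h"
    #include "llvm/Support/Path.h"

    static llvm::StringRef locFileName(llvm::StringRef bufferName,
                                       bool forSourceOutput) {
      // Absolute TableGen input paths depend on the build tree; keep only the
      // filename when emitting generated sources so the output is reproducible.
      if (forSourceOutput && llvm::sys::path::is_absolute(bufferName))
        return llvm::sys::path::filename(bufferName);
      return bufferName;
    }

    // locFileName("/build/src/mlir/lib/Dialect/Foo/FooPatterns.td", true)
    //   == "FooPatterns.td"
    // locFileName("FooPatterns.td", true) == "FooPatterns.td"  (unchanged)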
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 1243511..15c23c6 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -70,6 +70,7 @@ static inline LogicalResult interleaveCommaWithError(const Container &c,
/// imply higher precedence.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
+ .Case<emitc::AddressOfOp>([&](auto op) { return 15; })
.Case<emitc::AddOp>([&](auto op) { return 12; })
.Case<emitc::ApplyOp>([&](auto op) { return 15; })
.Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
@@ -111,6 +112,8 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
.Default([](auto op) { return op->emitError("unsupported operation"); });
}
+static bool shouldBeInlined(Operation *op);
+
namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter {
@@ -173,8 +176,11 @@ struct CppEmitter {
/// Emits the operands of the operation. All operands are emitted in order.
LogicalResult emitOperands(Operation &op);
- /// Emits value as an operands of an operation
- LogicalResult emitOperand(Value value);
+  /// Emits the value as an operand of an operation. Unless \p isInBrackets is
+ /// true, operands emitted as sub-expressions will be parenthesized if needed
+ /// in order to enforce correct evaluation based on precedence and
+ /// associativity.
+ LogicalResult emitOperand(Value value, bool isInBrackets = false);
/// Emit an expression as a C expression.
LogicalResult emitExpression(ExpressionOp expressionOp);
@@ -189,15 +195,6 @@ struct CppEmitter {
/// emitc::ForOp.
StringRef getOrCreateInductionVarName(Value val);
- // Returns the textual representation of a subscript operation.
- std::string getSubscriptName(emitc::SubscriptOp op);
-
- // Returns the textual representation of a member (of object) operation.
- std::string createMemberAccess(emitc::MemberOp op);
-
- // Returns the textual representation of a member of pointer operation.
- std::string createMemberAccess(emitc::MemberOfPtrOp op);
-
/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);
@@ -259,25 +256,20 @@ struct CppEmitter {
return !fileId.empty() && file.getId() == fileId;
}
- /// Get expression currently being emitted.
- ExpressionOp getEmittedExpression() { return emittedExpression; }
+  /// Returns whether an expression is currently being emitted.
+ bool isEmittingExpression() { return !emittedExpressionPrecedence.empty(); }
/// Determine whether given value is part of the expression potentially being
/// emitted.
bool isPartOfCurrentExpression(Value value) {
- if (!emittedExpression)
- return false;
Operation *def = value.getDefiningOp();
- if (!def)
- return false;
- return isPartOfCurrentExpression(def);
+ return def ? isPartOfCurrentExpression(def) : false;
}
/// Determine whether given operation is part of the expression potentially
/// being emitted.
bool isPartOfCurrentExpression(Operation *def) {
- auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
- return operandExpression && operandExpression == emittedExpression;
+ return isEmittingExpression() && shouldBeInlined(def);
};
// Resets the value counter to 0.
@@ -324,7 +316,6 @@ private:
unsigned int valueCount{0};
/// State of the current expression being emitted.
- ExpressionOp emittedExpression;
SmallVector<int> emittedExpressionPrecedence;
void pushExpressionPrecedence(int precedence) {
@@ -342,17 +333,28 @@ private:
/// Determine whether expression \p op should be emitted in a deferred way.
static bool hasDeferredEmission(Operation *op) {
- return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
+ return isa_and_nonnull<emitc::DereferenceOp, emitc::GetGlobalOp,
+ emitc::LiteralOp, emitc::MemberOp,
emitc::MemberOfPtrOp, emitc::SubscriptOp,
emitc::GetFieldOp>(op);
}
-/// Determine whether expression \p expressionOp should be emitted inline, i.e.
+/// Determine whether operation \p op should be emitted inline, i.e.
/// as part of its user. This function recommends inlining of any expressions
/// that can be inlined unless it is used by another expression, under the
/// assumption that any expression fusion/re-materialization was taken care of
/// by transformations run by the backend.
-static bool shouldBeInlined(ExpressionOp expressionOp) {
+static bool shouldBeInlined(Operation *op) {
+ // CExpression operations are inlined if and only if they reside within an
+ // ExpressionOp.
+ if (isa<CExpressionInterface>(op))
+ return isa<ExpressionOp>(op->getParentOp());
+
+  // The only other inlinable operation is ExpressionOp itself.
+ ExpressionOp expressionOp = dyn_cast<ExpressionOp>(op);
+ if (!expressionOp)
+ return false;
+
// Do not inline if expression is marked as such.
if (expressionOp.getDoNotInline())
return false;
@@ -402,6 +404,66 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
return false;
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::DereferenceOp dereferenceOp) {
+ std::string out;
+ llvm::raw_string_ostream ss(out);
+ ss << "*" << emitter.getOrCreateName(dereferenceOp.getPointer());
+ emitter.cacheDeferredOpResult(dereferenceOp.getResult(), out);
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::GetFieldOp getFieldOp) {
+ emitter.cacheDeferredOpResult(getFieldOp.getResult(),
+ getFieldOp.getFieldName());
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::GetGlobalOp getGlobalOp) {
+ emitter.cacheDeferredOpResult(getGlobalOp.getResult(), getGlobalOp.getName());
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::LiteralOp literalOp) {
+ emitter.cacheDeferredOpResult(literalOp.getResult(), literalOp.getValue());
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::MemberOp memberOp) {
+ std::string out;
+ llvm::raw_string_ostream ss(out);
+ ss << emitter.getOrCreateName(memberOp.getOperand());
+ ss << "." << memberOp.getMember();
+ emitter.cacheDeferredOpResult(memberOp.getResult(), out);
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::MemberOfPtrOp memberOfPtrOp) {
+ std::string out;
+ llvm::raw_string_ostream ss(out);
+ ss << emitter.getOrCreateName(memberOfPtrOp.getOperand());
+ ss << "->" << memberOfPtrOp.getMember();
+ emitter.cacheDeferredOpResult(memberOfPtrOp.getResult(), out);
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::SubscriptOp subscriptOp) {
+ std::string out;
+ llvm::raw_string_ostream ss(out);
+ ss << emitter.getOrCreateName(subscriptOp.getValue());
+ for (auto index : subscriptOp.getIndices()) {
+ ss << "[" << emitter.getOrCreateName(index) << "]";
+ }
+ emitter.cacheDeferredOpResult(subscriptOp.getResult(), out);
+ return success();
+}
+
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
Attribute value) {
OpResult result = operation->getResult(0);
@@ -435,6 +497,17 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
}
static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::AddressOfOp addressOfOp) {
+ raw_ostream &os = emitter.ostream();
+ Operation &op = *addressOfOp.getOperation();
+
+ if (failed(emitter.emitAssignPrefix(op)))
+ return failure();
+ os << "&";
+ return emitter.emitOperand(addressOfOp.getReference());
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation();
Attribute value = constantOp.getValue();
@@ -1336,32 +1409,6 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
labelInScopeCount.push(0);
}
-std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
- std::string out;
- llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getValue());
- for (auto index : op.getIndices()) {
- ss << "[" << getOrCreateName(index) << "]";
- }
- return out;
-}
-
-std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
- std::string out;
- llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getOperand());
- ss << "." << op.getMember();
- return out;
-}
-
-std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
- std::string out;
- llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getOperand());
- ss << "->" << op.getMember();
- return out;
-}
-
void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
if (!valueMapper.count(value))
valueMapper.insert(value, str.str());
@@ -1545,7 +1592,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
"Expected precedence stack to be empty");
Operation *rootOp = expressionOp.getRootOp();
- emittedExpression = expressionOp;
FailureOr<int> precedence = getOperatorPrecedence(rootOp);
if (failed(precedence))
return failure();
@@ -1557,12 +1603,11 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
popExpressionPrecedence();
assert(emittedExpressionPrecedence.empty() &&
"Expected precedence stack to be empty");
- emittedExpression = nullptr;
return success();
}
-LogicalResult CppEmitter::emitOperand(Value value) {
+LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) {
if (isPartOfCurrentExpression(value)) {
Operation *def = value.getDefiningOp();
assert(def && "Expected operand to be defined by an operation");
@@ -1570,10 +1615,12 @@ LogicalResult CppEmitter::emitOperand(Value value) {
if (failed(precedence))
return failure();
- // Sub-expressions with equal or lower precedence need to be parenthesized,
- // as they might be evaluated in the wrong order depending on the shape of
- // the expression tree.
- bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
+ // Unless already in brackets, sub-expressions with equal or lower
+ // precedence need to be parenthesized as they might be evaluated in the
+ // wrong order depending on the shape of the expression tree.
+ bool encloseInParenthesis =
+ !isInBrackets && precedence.value() <= getExpressionPrecedence();
+
if (encloseInParenthesis)
os << "(";
pushExpressionPrecedence(precedence.value());
@@ -1596,14 +1643,8 @@ LogicalResult CppEmitter::emitOperand(Value value) {
// If this operand is a block argument of an expression, emit instead the
// matching expression parameter.
Operation *argOp = arg.getParentBlock()->getParentOp();
- if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) {
- // This scenario is only expected when one of the operations within the
- // expression being emitted references one of the expression's block
- // arguments.
- assert(expressionOp == emittedExpression &&
- "Expected expression being emitted");
- value = expressionOp->getOperand(arg.getArgNumber());
- }
+ if (auto expressionOp = dyn_cast<ExpressionOp>(argOp))
+ return emitOperand(expressionOp->getOperand(arg.getArgNumber()));
}
os << getOrCreateName(value);
@@ -1612,15 +1653,9 @@ LogicalResult CppEmitter::emitOperand(Value value) {
LogicalResult CppEmitter::emitOperands(Operation &op) {
return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
- // If an expression is being emitted, push lowest precedence as these
- // operands are either wrapped by parenthesis.
- if (getEmittedExpression())
- pushExpressionPrecedence(lowestPrecedence());
- if (failed(emitOperand(operand)))
- return failure();
- if (getEmittedExpression())
- popExpressionPrecedence();
- return success();
+    // Emit the operand under the guarantee that, if it is part of an
+    // expression, it is emitted within brackets.
+ return emitOperand(operand, /*isInBrackets=*/true);
});
}
@@ -1702,7 +1737,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
// If op is being emitted as part of an expression, bail out.
- if (getEmittedExpression())
+ if (isEmittingExpression())
return success();
switch (op.getNumResults()) {
@@ -1753,49 +1788,27 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
- .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
- emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
+ .Case<emitc::AddressOfOp, emitc::AddOp, emitc::ApplyOp,
+ emitc::AssignOp, emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp,
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
- emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp,
- emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp,
- emitc::ForOp, emitc::FuncOp, emitc::GlobalOp, emitc::IfOp,
- emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
- emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
- emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
- emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
- emitc::VerbatimOp>(
+ emitc::DeclareFuncOp, emitc::DereferenceOp, emitc::DivOp,
+ emitc::DoOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp,
+ emitc::ForOp, emitc::FuncOp, emitc::GetFieldOp,
+ emitc::GetGlobalOp, emitc::GlobalOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp,
+ emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
+ emitc::MemberOfPtrOp, emitc::MemberOp, emitc::MulOp,
+ emitc::RemOp, emitc::ReturnOp, emitc::SubscriptOp, emitc::SubOp,
+ emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
+ emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
- .Case<emitc::GetGlobalOp>([&](auto op) {
- cacheDeferredOpResult(op.getResult(), op.getName());
- return success();
- })
- .Case<emitc::GetFieldOp>([&](auto op) {
- cacheDeferredOpResult(op.getResult(), op.getFieldName());
- return success();
- })
- .Case<emitc::LiteralOp>([&](auto op) {
- cacheDeferredOpResult(op.getResult(), op.getValue());
- return success();
- })
- .Case<emitc::MemberOp>([&](auto op) {
- cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
- return success();
- })
- .Case<emitc::MemberOfPtrOp>([&](auto op) {
- cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
- return success();
- })
- .Case<emitc::SubscriptOp>([&](auto op) {
- cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
- return success();
- })
.Default([&](Operation *) {
return op.emitOpError("unable to find printer for op");
});
@@ -1806,7 +1819,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (hasDeferredEmission(&op))
return success();
- if (getEmittedExpression() ||
+ if (isEmittingExpression() ||
(isa<emitc::ExpressionOp>(op) &&
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
return success();
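The rewritten emitOperand boils down to one parenthesization rule: a sub-expression is wrapped in parentheses when its operator does not bind strictly tighter than its user's, except when the caller already guarantees enclosing brackets (operand lists emitted via emitOperands). A standalone sketch of that decision, with illustrative precedence values in the style of getOperatorPrecedence (higher binds tighter):

    // Sketch of the parenthesization decision made in CppEmitter::emitOperand.
    static bool needsParens(int operandPrecedence, int userPrecedence,
                            bool isInBrackets) {
      // Inside brackets (e.g. call arguments), evaluation order is already
      // unambiguous, so never add parentheses.
      if (isInBrackets)
        return false;
      // Equal or lower precedence could re-associate, e.g. the `a + b` operand
      // of a multiplication must not be printed as `a + b * c`.
      return operandPrecedence <= userPrecedence;
    }

    // needsParens(/*add=*/12, /*mul=*/13, false) -> true   =>  (a + b) * c
    // needsParens(/*mul=*/13, /*add=*/12, false) -> false  =>  a * b + c
    // needsParens(/*add=*/12, /*any=*/13, true)  -> false  =>  f(a + b)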
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 2dd0640..5be33c4 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -30,6 +30,14 @@ void registerFromLLVMIRTranslation() {
llvm::cl::desc("Emit expensive warnings during LLVM IR import "
"(discouraged: testing only!)"),
llvm::cl::init(false));
+ static llvm::cl::opt<bool> convertDebugRecToIntrinsics(
+ "convert-debug-rec-to-intrinsics",
+ llvm::cl::desc("Change the input LLVM module to use old debug intrinsics "
+ "instead of records "
+ "via convertFromNewDbgValues, this happens "
+ "before importing the debug information"
+ "(discouraged: to be removed soon!)"),
+ llvm::cl::init(false));
static llvm::cl::opt<bool> dropDICompositeTypeElements(
"drop-di-composite-type-elements",
llvm::cl::desc(
@@ -69,8 +77,10 @@ void registerFromLLVMIRTranslation() {
if (llvm::verifyModule(*llvmModule, &llvm::errs()))
return nullptr;
- // Debug records are not currently supported in the LLVM IR translator.
- llvmModule->convertFromNewDbgValues();
+  // Now that the translation supports importing debug records directly,
+  // make that the default, but let the user opt back into the old behavior.
+ if (convertDebugRecToIntrinsics)
+ llvmModule->convertFromNewDbgValues();
return translateLLVMIRToModule(
std::move(llvmModule), context, emitExpensiveWarnings,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index d3216d9..d9bfe65 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -124,10 +124,10 @@ static LogicalResult embedBinaryImpl(StringRef moduleName,
}
IRBuilder<> builder(module.getContext());
- auto i32Ty = builder.getInt32Ty();
- auto i64Ty = builder.getInt64Ty();
- auto ptrTy = builder.getPtrTy(0);
- auto voidTy = builder.getVoidTy();
+ auto *i32Ty = builder.getInt32Ty();
+ auto *i64Ty = builder.getInt64Ty();
+ auto *ptrTy = builder.getPtrTy(0);
+ auto *voidTy = builder.getVoidTy();
// Embed the module as a global object.
auto *modulePtr = new GlobalVariable(
@@ -147,13 +147,12 @@ static LogicalResult embedBinaryImpl(StringRef moduleName,
"mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
Constant *optValue = ConstantInt::get(i32Ty, optLevel);
return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
- } else {
- FunctionCallee moduleLoadFn = module.getOrInsertFunction(
- "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
- Constant *binarySize =
- ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
- return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
}
+ FunctionCallee moduleLoadFn = module.getOrInsertFunction(
+ "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
+ Constant *binarySize =
+ ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
+ return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
}();
builder.CreateStore(moduleObj, modulePtr);
builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 44732d5..2d4a18c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -80,8 +80,9 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
/// dialect attributes.
-static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
- static const SmallVector<unsigned> convertibleMetadata = {
+static SmallVector<unsigned>
+getSupportedMetadataImpl(llvm::LLVMContext &llvmContext) {
+ SmallVector<unsigned> convertibleMetadata = {
llvm::LLVMContext::MD_prof,
llvm::LLVMContext::MD_tbaa,
llvm::LLVMContext::MD_access_group,
@@ -91,10 +92,10 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
llvm::LLVMContext::MD_dereferenceable,
llvm::LLVMContext::MD_dereferenceable_or_null,
llvm::LLVMContext::MD_mmra,
- context.getMDKindID(vecTypeHintMDName),
- context.getMDKindID(workGroupSizeHintMDName),
- context.getMDKindID(reqdWorkGroupSizeMDName),
- context.getMDKindID(intelReqdSubGroupSizeMDName)};
+ llvmContext.getMDKindID(vecTypeHintMDName),
+ llvmContext.getMDKindID(workGroupSizeHintMDName),
+ llvmContext.getMDKindID(reqdWorkGroupSizeMDName),
+ llvmContext.getMDKindID(intelReqdSubGroupSizeMDName)};
return convertibleMetadata;
}
@@ -113,7 +114,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
return failure();
// Handle function entry count metadata.
- if (name->getString() == "function_entry_count") {
+ if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) {
// TODO support function entry count metadata with GUID fields.
if (node->getNumOperands() != 2)
@@ -131,15 +132,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
<< "expected function_entry_count to be attached to a function";
}
- if (name->getString() != "branch_weights")
+ if (name->getString() != llvm::MDProfLabels::BranchWeights)
return failure();
+ // The branch_weights metadata must have at least 2 operands.
+ if (node->getNumOperands() < 2)
+ return failure();
+
+ ArrayRef<llvm::MDOperand> branchWeightOperands =
+ node->operands().drop_front();
+ if (auto *mdString = dyn_cast<llvm::MDString>(node->getOperand(1))) {
+ if (mdString->getString() != llvm::MDProfLabels::ExpectedBranchWeights)
+ return failure();
+ // The MLIR WeightedBranchOpInterface does not support the
+ // ExpectedBranchWeights field, so it is dropped.
+ branchWeightOperands = branchWeightOperands.drop_front();
+ }
// Handle branch weights metadata.
SmallVector<int32_t> branchWeights;
- branchWeights.reserve(node->getNumOperands() - 1);
- for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) {
+ branchWeights.reserve(branchWeightOperands.size());
+ for (const llvm::MDOperand &operand : branchWeightOperands) {
llvm::ConstantInt *branchWeight =
- llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i));
+ llvm::mdconst::dyn_extract<llvm::ConstantInt>(operand);
if (!branchWeight)
return failure();
branchWeights.push_back(branchWeight->getZExtValue());
@@ -492,9 +506,9 @@ public:
/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
/// LLVM dialect attributes.
- ArrayRef<unsigned>
- getSupportedMetadata(llvm::LLVMContext &context) const final {
- return getSupportedMetadataImpl(context);
+ SmallVector<unsigned>
+ getSupportedMetadata(llvm::LLVMContext &llvmContext) const final {
+ return getSupportedMetadataImpl(llvmContext);
}
};
} // namespace
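The profiling import now accepts both encodings of branch-weight metadata: the plain form and the one produced when llvm.expect is lowered, which inserts an extra "expected" marker string after the name; the marker is dropped because WeightedBranchOpInterface only models the weights. A small sketch constructing both shapes with the plain metadata API (the weight values are placeholders):

    // The two !prof shapes handled above, built with MDNode::get directly.
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Metadata.h"
    #include "llvm/IR/Type.h"

    static llvm::MDNode *makeBranchWeights(llvm::LLVMContext &ctx, bool expected) {
      auto weight = [&](uint32_t w) -> llvm::Metadata * {
        return llvm::ConstantAsMetadata::get(
            llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), w));
      };
      llvm::SmallVector<llvm::Metadata *> ops;
      // Matches MDProfLabels::BranchWeights.
      ops.push_back(llvm::MDString::get(ctx, "branch_weights"));
      if (expected) // Optional MDProfLabels::ExpectedBranchWeights marker.
        ops.push_back(llvm::MDString::get(ctx, "expected"));
      ops.push_back(weight(2000));
      ops.push_back(weight(1));
      return llvm::MDNode::get(ctx, ops);
    }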
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index eaf1d20..b6ea4ba 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -222,14 +222,14 @@ static void convertLinkerOptionsOp(ArrayAttr options,
llvm::LLVMContext &context = llvmModule->getContext();
llvm::NamedMDNode *linkerMDNode =
llvmModule->getOrInsertNamedMetadata("llvm.linker.options");
- SmallVector<llvm::Metadata *> MDNodes;
- MDNodes.reserve(options.size());
+ SmallVector<llvm::Metadata *> mdNodes;
+ mdNodes.reserve(options.size());
for (auto s : options.getAsRange<StringAttr>()) {
- auto *MDNode = llvm::MDString::get(context, s.getValue());
- MDNodes.push_back(MDNode);
+ auto *mdNode = llvm::MDString::get(context, s.getValue());
+ mdNodes.push_back(mdNode);
}
- auto *listMDNode = llvm::MDTuple::get(context, MDNodes);
+ auto *listMDNode = llvm::MDTuple::get(context, mdNodes);
linkerMDNode->addOperand(listMDNode);
}
@@ -243,16 +243,16 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) {
- llvm::Metadata *fromMetadata =
- entry.getFrom()
- ? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction(
- entry.getFrom().getValue()))
- : nullptr;
- llvm::Metadata *toMetadata =
- entry.getTo()
- ? llvm::ValueAsMetadata::get(
- moduleTranslation.lookupFunction(entry.getTo().getValue()))
- : nullptr;
+ auto getFuncMetadata = [&](FlatSymbolRefAttr sym) -> llvm::Metadata * {
+ if (!sym)
+ return nullptr;
+ if (llvm::Function *fn =
+ moduleTranslation.lookupFunction(sym.getValue()))
+ return llvm::ValueAsMetadata::get(fn);
+ return nullptr;
+ };
+ llvm::Metadata *fromMetadata = getFuncMetadata(entry.getFrom());
+ llvm::Metadata *toMetadata = getFuncMetadata(entry.getTo());
llvm::Metadata *vals[] = {
fromMetadata, toMetadata,
@@ -439,7 +439,14 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::MemoryEffects::Location::InaccessibleMem,
convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) |
llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
- convertModRefInfoToLLVM(memAttr.getOther()));
+ convertModRefInfoToLLVM(memAttr.getOther())) |
+ llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem,
+ convertModRefInfoToLLVM(memAttr.getErrnoMem())) |
+ llvm::MemoryEffects(
+ llvm::MemoryEffects::Location::TargetMem0,
+ convertModRefInfoToLLVM(memAttr.getTargetMem0())) |
+ llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1,
+ convertModRefInfoToLLVM(memAttr.getTargetMem1()));
call->setMemoryEffects(memEffects);
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index cecff51..b7427a5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -411,6 +411,41 @@ getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
llvm_unreachable("unhandled tcgen05.st lowering");
}
+static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order) {
+ return order == NVVM::MemOrderKind::ACQUIRE
+ ? llvm::Intrinsic::
+ nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster
+ : llvm::Intrinsic::
+ nvvm_fence_release_sync_restrict_space_cta_scope_cluster;
+}
+
+static llvm::Intrinsic::ID
+getFenceProxyID(NVVM::ProxyKind kind, std::optional<NVVM::SharedSpace> space) {
+ switch (kind) {
+ case NVVM::ProxyKind::alias:
+ return llvm::Intrinsic::nvvm_fence_proxy_alias;
+ case NVVM::ProxyKind::async:
+ return llvm::Intrinsic::nvvm_fence_proxy_async;
+ case NVVM::ProxyKind::async_global:
+ return llvm::Intrinsic::nvvm_fence_proxy_async_global;
+ case NVVM::ProxyKind::async_shared:
+ return *space == NVVM::SharedSpace::shared_cta
+ ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta
+ : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster;
+ default:
+ llvm_unreachable("unsupported proxy kind");
+ }
+}
+
+static llvm::Intrinsic::ID
+getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
+ return order == NVVM::MemOrderKind::ACQUIRE
+ ? llvm::Intrinsic::
+ nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster
+ : llvm::Intrinsic::
+ nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
+}
+
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8edec99..03d67a5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -61,6 +61,8 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
return llvm::omp::OMP_SCHEDULE_Auto;
case omp::ClauseScheduleKind::Runtime:
return llvm::omp::OMP_SCHEDULE_Runtime;
+ case omp::ClauseScheduleKind::Distribute:
+ return llvm::omp::OMP_SCHEDULE_Distribute;
}
llvm_unreachable("unhandled schedule clause argument");
}
@@ -135,28 +137,31 @@ class LinearClauseProcessor {
private:
SmallVector<llvm::Value *> linearPreconditionVars;
SmallVector<llvm::Value *> linearLoopBodyTemps;
- SmallVector<llvm::AllocaInst *> linearOrigVars;
SmallVector<llvm::Value *> linearOrigVal;
SmallVector<llvm::Value *> linearSteps;
+ SmallVector<llvm::Type *> linearVarTypes;
llvm::BasicBlock *linearFinalizationBB;
llvm::BasicBlock *linearExitBB;
llvm::BasicBlock *linearLastIterExitBB;
public:
+  // Register the converted LLVM type of a linear variable.
+ void registerType(LLVM::ModuleTranslation &moduleTranslation,
+ mlir::Attribute &ty) {
+ linearVarTypes.push_back(moduleTranslation.convertType(
+ mlir::cast<mlir::TypeAttr>(ty).getValue()));
+ }
+
   // Allocate space for linear variables
void createLinearVar(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- mlir::Value &linearVar) {
- if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
- moduleTranslation.lookupValue(linearVar))) {
- linearPreconditionVars.push_back(builder.CreateAlloca(
- linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
- llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
- linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
- linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
- linearLoopBodyTemps.push_back(linearLoopBodyTemp);
- linearOrigVars.push_back(linearVarAlloca);
- }
+ mlir::Value &linearVar, int idx) {
+ linearPreconditionVars.push_back(
+ builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var"));
+ llvm::Value *linearLoopBodyTemp =
+ builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result");
+ linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+ linearLoopBodyTemps.push_back(linearLoopBodyTemp);
}
// Initialize linear step
@@ -166,20 +171,15 @@ public:
}
// Emit IR for initialization of linear variables
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy
- initLinearVar(llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation,
- llvm::BasicBlock *loopPreHeader) {
+ void initLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::BasicBlock *loopPreHeader) {
builder.SetInsertPoint(loopPreHeader->getTerminator());
- for (size_t index = 0; index < linearOrigVars.size(); index++) {
- llvm::LoadInst *linearVarLoad = builder.CreateLoad(
- linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+ for (size_t index = 0; index < linearOrigVal.size(); index++) {
+ llvm::LoadInst *linearVarLoad =
+ builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]);
builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
}
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
- moduleTranslation.getOpenMPBuilder()->createBarrier(
- builder.saveIP(), llvm::omp::OMPD_barrier);
- return afterBarrierIP;
}
// Emit IR for updating Linear variables
@@ -188,20 +188,24 @@ public:
builder.SetInsertPoint(loopBody->getTerminator());
for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
// Emit increments for linear vars
- llvm::LoadInst *linearVarStart =
- builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
-
- linearPreconditionVars[index]);
+ llvm::LoadInst *linearVarStart = builder.CreateLoad(
+ linearVarTypes[index], linearPreconditionVars[index]);
auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
- auto addInst = builder.CreateAdd(linearVarStart, mulInst);
- builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+ if (linearVarTypes[index]->isIntegerTy()) {
+ auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+ builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+ } else if (linearVarTypes[index]->isFloatingPointTy()) {
+ auto cvt = builder.CreateSIToFP(mulInst, linearVarTypes[index]);
+ auto addInst = builder.CreateFAdd(linearVarStart, cvt);
+ builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+ }
}
}
// Linear variable finalization is conditional on the last logical iteration.
// Create BB splits to manage the same.
- void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
- llvm::BasicBlock *loopExit) {
+ void splitLinearFiniBB(llvm::IRBuilderBase &builder,
+ llvm::BasicBlock *loopExit) {
linearFinalizationBB = loopExit->splitBasicBlock(
loopExit->getTerminator(), "omp_loop.linear_finalization");
linearExitBB = linearFinalizationBB->splitBasicBlock(
@@ -225,11 +229,10 @@ public:
llvm::Type::getInt32Ty(builder.getContext()), 0));
// Store the linear variable values to original variables.
builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
- for (size_t index = 0; index < linearOrigVars.size(); index++) {
+ for (size_t index = 0; index < linearOrigVal.size(); index++) {
llvm::LoadInst *linearVarTemp =
- builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
- linearLoopBodyTemps[index]);
- builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+ builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]);
+ builder.CreateStore(linearVarTemp, linearOrigVal[index]);
}
// Create conditional branch such that the linear variable
@@ -253,7 +256,8 @@ public:
users.push_back(user);
for (auto *user : users) {
if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
- if (userInst->getParent()->getName().str() == BBName)
+ if (userInst->getParent()->getName().str().find(BBName) !=
+ std::string::npos)
user->replaceUsesOfWith(linearOrigVal[varIndex],
linearLoopBodyTemps[varIndex]);
}
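In scalar terms, the update the rewritten LinearClauseProcessor emits for each linear variable is the usual recurrence over the logical iteration number, with an extra int-to-float conversion for floating-point variables; a scalar model of both paths (names are illustrative):

    // Scalar model of the per-iteration update emitted for linear clauses.
    static long long linearIntValueAt(long long start, long long iv,
                                      long long step) {
      return start + iv * step;                       // CreateMul + CreateAdd
    }
    static double linearFloatValueAt(double start, long long iv,
                                     long long step) {
      return start + static_cast<double>(iv * step);  // CreateSIToFP + CreateFAdd
    }
    // After the last logical iteration, the computed value is stored back to
    // the original variable in the finalization blocks split off the loop exit.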
@@ -319,10 +323,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getDevice())
result = todo("device");
};
- auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
- if (op.getDistScheduleChunkSize())
- result = todo("dist_schedule with chunk_size");
- };
auto checkHint = [](auto op, LogicalResult &) {
if (op.getHint())
op.emitWarning("hint clause discarded");
@@ -332,14 +332,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getInReductionSyms())
result = todo("in_reduction");
};
- auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
- if (!op.getIsDevicePtrVars().empty())
- result = todo("is_device_ptr");
- };
- auto checkLinear = [&todo](auto op, LogicalResult &result) {
- if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
- result = todo("linear");
- };
auto checkNowait = [&todo](auto op, LogicalResult &result) {
if (op.getNowait())
result = todo("nowait");
@@ -387,7 +379,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::DistributeOp op) {
checkAllocate(op, result);
- checkDistSchedule(op, result);
checkOrder(op, result);
})
.Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
@@ -423,7 +414,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::WsloopOp op) {
checkAllocate(op, result);
- checkLinear(op, result);
checkOrder(op, result);
checkReduction(op, result);
})
@@ -431,10 +421,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkReduction(op, result);
})
- .Case([&](omp::SimdOp op) {
- checkLinear(op, result);
- checkReduction(op, result);
- })
+ .Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
.Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
@@ -444,7 +431,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkInReduction(op, result);
- checkIsDevicePtr(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
@@ -953,6 +939,9 @@ using OwningAtomicReductionGen =
std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
llvm::Value *)>;
+using OwningDataPtrPtrReductionGen =
+ std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
+ llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>;
} // namespace
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
@@ -1017,6 +1006,35 @@ makeAtomicReductionGen(omp::DeclareReductionOp decl,
return atomicGen;
}
+/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for
+/// the given reduction declaration. The generator uses `builder` but ignores
+/// its insertion point. Returns null if there is no `data_ptr_ptr` region
+/// available in the reduction declaration.
+static OwningDataPtrPtrReductionGen
+makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation, bool isByRef) {
+ if (!isByRef)
+ return OwningDataPtrPtrReductionGen();
+
+ OwningDataPtrPtrReductionGen refDataPtrGen =
+ [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
+ llvm::Value *byRefVal, llvm::Value *&result) mutable
+ -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
+ moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal);
+ builder.restoreIP(insertPoint);
+ SmallVector<llvm::Value *> phis;
+ if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(),
+ "omp.data_ptr_ptr.body", builder,
+ moduleTranslation, &phis)))
+ return llvm::createStringError(
+ "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`");
+ result = llvm::getSingleElement(phis);
+ return builder.saveIP();
+ };
+
+ return refDataPtrGen;
+}
+
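As a rough illustration of the convention used here, where an empty std::function signals that the by-ref path does not apply, a simplified sketch with placeholder types:

#include <functional>

// Placeholder signature standing in for OwningDataPtrPtrReductionGen.
using RefGen = std::function<bool(int *byRefVal, int *&result)>;

static RefGen makeRefGen(bool isByRef) {
  if (!isByRef)
    return RefGen(); // null generator: callers skip the by-ref handling
  return [](int *byRefVal, int *&result) {
    result = byRefVal; // hand back the pointer the runtime should reduce into
    return true;
  };
}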
/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1170,6 +1188,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
template <typename T>
static void
mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
+ llvm::IRBuilderBase &builder,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
DenseMap<Value, llvm::Value *> &reductionVariableMap,
unsigned i) {
@@ -1180,8 +1199,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
mlir::Value mlirSource = loop.getReductionVars()[i];
llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
- assert(llvmSource && "lookup reduction var");
- moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
+ llvm::Value *origVal = llvmSource;
+ // If a non-pointer value is expected, load the value from the source pointer.
+ if (!isa<LLVM::LLVMPointerType>(
+ reduction.getInitializerMoldArg().getType()) &&
+ isa<LLVM::LLVMPointerType>(mlirSource.getType())) {
+ origVal =
+ builder.CreateLoad(moduleTranslation.convertType(
+ reduction.getInitializerMoldArg().getType()),
+ llvmSource, "omp_orig");
+ }
+ moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal);
if (entry.getNumArguments() > 1) {
llvm::Value *allocation =
@@ -1254,7 +1282,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
SmallVector<llvm::Value *, 1> phis;
// map block argument to initializer region
- mapInitializationArgs(op, moduleTranslation, reductionDecls,
+ mapInitializationArgs(op, moduleTranslation, builder, reductionDecls,
reductionVariableMap, i);
  // TODO In some cases (especially on the GPU), the init regions may
@@ -1310,8 +1338,10 @@ static void collectReductionInfo(
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
SmallVectorImpl<OwningReductionGen> &owningReductionGens,
SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
+ SmallVector<OwningDataPtrPtrReductionGen> &owningDataPtrPtrReductionGens,
const ArrayRef<llvm::Value *> privateReductionVariables,
- SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
+ SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos,
+ ArrayRef<bool> isByRef) {
unsigned numReductions = loop.getNumReductionVars();
for (unsigned i = 0; i < numReductions; ++i) {
@@ -1319,6 +1349,8 @@ static void collectReductionInfo(
makeReductionGen(reductionDecls[i], builder, moduleTranslation));
owningAtomicReductionGens.push_back(
makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
+ owningDataPtrPtrReductionGens.push_back(makeRefDataPtrGen(
+ reductionDecls[i], builder, moduleTranslation, isByRef[i]));
}
// Collect the reduction information.
@@ -1329,12 +1361,28 @@ static void collectReductionInfo(
atomicGen = owningAtomicReductionGens[i];
llvm::Value *variable =
moduleTranslation.lookupValue(loop.getReductionVars()[i]);
+ mlir::Type allocatedType;
+ reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) {
+ if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) {
+ allocatedType = alloca.getElemType();
+ return mlir::WalkResult::interrupt();
+ }
+
+ return mlir::WalkResult::advance();
+ });
+
reductionInfos.push_back(
{moduleTranslation.convertType(reductionDecls[i].getType()), variable,
privateReductionVariables[i],
/*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
owningReductionGens[i],
- /*ReductionGenClang=*/nullptr, atomicGen});
+ /*ReductionGenClang=*/nullptr, atomicGen,
+ owningDataPtrPtrReductionGens[i],
+ allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr,
+ reductionDecls[i].getByrefElementType()
+ ? moduleTranslation.convertType(
+ *reductionDecls[i].getByrefElementType())
+ : nullptr});
}
}
@@ -1392,7 +1440,8 @@ static LogicalResult createReductionsAndCleanup(
SmallVector<OwningReductionGen> owningReductionGens;
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
- SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos;
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
@@ -1400,7 +1449,8 @@ static LogicalResult createReductionsAndCleanup(
// ReductionInfo only accepts references to the generators.
collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
owningReductionGens, owningAtomicReductionGens,
- privateReductionVariables, reductionInfos);
+ owningReductionGenRefDataPtrGens,
+ privateReductionVariables, reductionInfos, isByRef);
// The call to createReductions below expects the block to have a
// terminator. Create an unreachable instruction to serve as terminator
@@ -1907,7 +1957,7 @@ static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
// If we are going to use distribute reduction then remove any debug uses of
// the reduction parameters in teamsOp. Otherwise they will be left without
// any mapped value in moduleTranslation and will eventually error out.
- for (auto use : debugUses)
+ for (auto *use : debugUses)
use->erase();
return true;
}
@@ -2484,6 +2534,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
}
+ omp::DistributeOp distributeOp = nullptr;
+ llvm::Value *distScheduleChunk = nullptr;
+ bool hasDistSchedule = false;
+ if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
+ distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
+ hasDistSchedule = distributeOp.getDistScheduleStatic();
+ if (distributeOp.getDistScheduleChunkSize()) {
+ llvm::Value *chunkVar = moduleTranslation.lookupValue(
+ distributeOp.getDistScheduleChunkSize());
+ distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
+ }
+ }
+
PrivateVarsInfo privateVarsInfo(wsloopOp);
SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2553,10 +2616,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// Initialize linear variables and linear step
LinearClauseProcessor linearClauseProcessor;
+
if (!wsloopOp.getLinearVars().empty()) {
- for (mlir::Value linearVar : wsloopOp.getLinearVars())
+ auto linearVarTypes = wsloopOp.getLinearVarTypes().value();
+ for (mlir::Attribute linearVarType : linearVarTypes)
+ linearClauseProcessor.registerType(moduleTranslation, linearVarType);
+
+ for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars()))
linearClauseProcessor.createLinearVar(builder, moduleTranslation,
- linearVar);
+ linearVar, idx);
for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
}
@@ -2571,16 +2639,17 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// Emit Initialization and Update IR for linear variables
if (!wsloopOp.getLinearVars().empty()) {
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
- linearClauseProcessor.initLinearVar(builder, moduleTranslation,
- loopInfo->getPreheader());
+ moduleTranslation.getOpenMPBuilder()->createBarrier(
+ builder.saveIP(), llvm::omp::OMPD_barrier);
if (failed(handleError(afterBarrierIP, *loopOp)))
return failure();
builder.restoreIP(*afterBarrierIP);
linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
loopInfo->getIndVar());
- linearClauseProcessor.outlineLinearFinalizationBB(builder,
- loopInfo->getExit());
+ linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
}
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2611,7 +2680,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
convertToScheduleKind(schedule), chunk, isSimd,
scheduleMod == omp::ScheduleModifier::monotonic,
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType, noLoopMode);
+ workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
if (failed(handleError(wsloopIP, opInst)))
return failure();
@@ -2655,6 +2724,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
assert(isByRef.size() == opInst.getNumReductionVars());
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ bool isCancellable = constructIsCancellable(opInst);
if (failed(checkImplementationStatus(*opInst)))
return failure();
@@ -2729,10 +2799,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// Collect reduction info
SmallVector<OwningReductionGen> owningReductionGens;
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
- SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+ SmallVector<OwningDataPtrPtrReductionGen>
+ owningReductionGenRefDataPtrGens;
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos;
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
owningReductionGens, owningAtomicReductionGens,
- privateReductionVariables, reductionInfos);
+ owningReductionGenRefDataPtrGens,
+ privateReductionVariables, reductionInfos, isByRef);
// Move to region cont block
builder.SetInsertPoint((*regionBlock)->getTerminator());
@@ -2790,6 +2863,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
privateVarsInfo.privatizers)))
return llvm::make_error<PreviouslyReportedError>();
+ // If we could be performing cancellation, add the cancellation barrier on
+ // the way out of the outlined region.
+ if (isCancellable) {
+ auto IPOrErr = ompBuilder->createBarrier(
+ llvm::OpenMPIRBuilder::LocationDescription(builder),
+ llvm::omp::Directive::OMPD_unknown,
+ /* ForceSimpleCall */ false,
+ /* CheckCancelFlag */ false);
+ if (!IPOrErr)
+ return IPOrErr.takeError();
+ }
+
builder.restoreIP(oldIP);
return llvm::Error::success();
};
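The early return used for the barrier above follows the usual llvm::Expected propagation pattern; a tiny sketch:

#include "llvm/Support/Error.h"

static llvm::Error useOrPropagate(llvm::Expected<int> valueOrErr) {
  if (!valueOrErr)
    return valueOrErr.takeError(); // propagate the failure to the caller
  // ... use *valueOrErr ...
  return llvm::Error::success();
}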
@@ -2803,7 +2888,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
pbKind = getProcBindKind(*bind);
- bool isCancellable = constructIsCancellable(opInst);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
@@ -2858,6 +2942,20 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
+ // Initialize linear variables and linear step
+ LinearClauseProcessor linearClauseProcessor;
+
+ if (!simdOp.getLinearVars().empty()) {
+ auto linearVarTypes = simdOp.getLinearVarTypes().value();
+ for (mlir::Attribute linearVarType : linearVarTypes)
+ linearClauseProcessor.registerType(moduleTranslation, linearVarType);
+ for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars()))
+ linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+ linearVar, idx);
+ for (mlir::Value linearStep : simdOp.getLinearStepVars())
+ linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+ }
+
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
builder, moduleTranslation, privateVarsInfo, allocaIP);
if (handleError(afterAllocas, opInst).failed())
@@ -2927,14 +3025,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(regionBlock, opInst)))
return failure();
- builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+ // Emit Initialization for linear variables
+ if (simdOp.getLinearVars().size()) {
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
+
+ linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+ loopInfo->getIndVar());
+ }
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+
ompBuilder->applySimd(loopInfo, alignedVars,
simdOp.getIfExpr()
? moduleTranslation.lookupValue(simdOp.getIfExpr())
: nullptr,
order, simdlen, safelen);
+ for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
+ linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+ index);
+
// We now need to reduce the per-simd-lane reduction variable into the
// original variable. This works a bit differently to other reductions (e.g.
// wsloop) because we don't need to call into the OpenMP runtime to handle
@@ -3632,10 +3743,23 @@ convertToCaptureClauseKind(
return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
case mlir::omp::DeclareTargetCaptureClause::enter:
return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
+ case mlir::omp::DeclareTargetCaptureClause::none:
+ return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone;
}
llvm_unreachable("unhandled capture clause");
}
+static Operation *getGlobalOpFromValue(Value value) {
+ Operation *op = value.getDefiningOp();
+ if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
+ op = addrCast->getOperand(0).getDefiningOp();
+ if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
+ auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
+ return modOp.lookupSymbol(addressOfOp.getGlobalName());
+ }
+ return nullptr;
+}
+
static llvm::SmallString<64>
getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
llvm::OpenMPIRBuilder &ompBuilder) {
@@ -3658,62 +3782,58 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
return suffix;
}
-static bool isDeclareTargetLink(mlir::Value value) {
- if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
- auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
- Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
- if (auto declareTargetGlobal =
- llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
- if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
- mlir::omp::DeclareTargetCaptureClause::link)
- return true;
- }
+static bool isDeclareTargetLink(Value value) {
+ if (auto declareTargetGlobal =
+ dyn_cast_if_present<omp::DeclareTargetInterface>(
+ getGlobalOpFromValue(value)))
+ if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
+ omp::DeclareTargetCaptureClause::link)
+ return true;
+ return false;
+}
+
+static bool isDeclareTargetTo(Value value) {
+ if (auto declareTargetGlobal =
+ dyn_cast_if_present<omp::DeclareTargetInterface>(
+ getGlobalOpFromValue(value)))
+ if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
+ omp::DeclareTargetCaptureClause::to ||
+ declareTargetGlobal.getDeclareTargetCaptureClause() ==
+ omp::DeclareTargetCaptureClause::enter)
+ return true;
return false;
}
-// Returns the reference pointer generated by the lowering of the declare target
-// operation in cases where the link clause is used or the to clause is used in
-// USM mode.
+// Returns the reference pointer generated by the lowering of the declare
+// target operation in cases where the link clause is used or the to clause is
+// used in USM mode.
static llvm::Value *
-getRefPtrIfDeclareTarget(mlir::Value value,
+getRefPtrIfDeclareTarget(Value value,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- Operation *op = value.getDefiningOp();
- if (auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op))
- op = addrCast->getOperand(0).getDefiningOp();
-
- // An easier way to do this may just be to keep track of any pointer
- // references and their mapping to their respective operation
- if (auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(op)) {
- if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
- addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
- addressOfOp.getGlobalName()))) {
-
- if (auto declareTargetGlobal =
- llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
- gOp.getOperation())) {
-
- // In this case, we must utilise the reference pointer generated by the
- // declare target operation, similar to Clang
- if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
- mlir::omp::DeclareTargetCaptureClause::link) ||
- (declareTargetGlobal.getDeclareTargetCaptureClause() ==
- mlir::omp::DeclareTargetCaptureClause::to &&
- ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
- llvm::SmallString<64> suffix =
- getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
-
- if (gOp.getSymName().contains(suffix))
- return moduleTranslation.getLLVMModule()->getNamedValue(
- gOp.getSymName());
+ if (auto gOp =
+ dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) {
+ if (auto declareTargetGlobal =
+ dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) {
+ // In this case, we must utilise the reference pointer generated by
+ // the declare target operation, similar to Clang
+ if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
+ omp::DeclareTargetCaptureClause::link) ||
+ (declareTargetGlobal.getDeclareTargetCaptureClause() ==
+ omp::DeclareTargetCaptureClause::to &&
+ ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
+ llvm::SmallString<64> suffix =
+ getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
+ if (gOp.getSymName().contains(suffix))
return moduleTranslation.getLLVMModule()->getNamedValue(
- (gOp.getSymName().str() + suffix.str()).str());
- }
+ gOp.getSymName());
+
+ return moduleTranslation.getLLVMModule()->getNamedValue(
+ (gOp.getSymName().str() + suffix.str()).str());
}
}
}
-
return nullptr;
}
@@ -3756,6 +3876,32 @@ struct MapInfoData : MapInfosTy {
MapInfosTy::append(CurInfo);
}
};
+
+enum class TargetDirectiveEnumTy : uint32_t {
+ None = 0,
+ Target = 1,
+ TargetData = 2,
+ TargetEnterData = 3,
+ TargetExitData = 4,
+ TargetUpdate = 5
+};
+
+static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) {
+ return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op)
+ .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; })
+ .Case([](omp::TargetEnterDataOp) {
+ return TargetDirectiveEnumTy::TargetEnterData;
+ })
+ .Case([&](omp::TargetExitDataOp) {
+ return TargetDirectiveEnumTy::TargetExitData;
+ })
+ .Case([&](omp::TargetUpdateOp) {
+ return TargetDirectiveEnumTy::TargetUpdate;
+ })
+ .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; })
+ .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; });
+}
+
} // namespace
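The llvm::TypeSwitch idiom used above, shown in isolation on the LLVM Value hierarchy (the types and enum here are chosen purely for illustration):

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"

enum class ValueKindTy { None, ConstantInt, Load, Store };

static ValueKindTy classify(llvm::Value *v) {
  return llvm::TypeSwitch<llvm::Value *, ValueKindTy>(v)
      .Case([](llvm::ConstantInt *) { return ValueKindTy::ConstantInt; })
      .Case([](llvm::LoadInst *) { return ValueKindTy::Load; })
      .Case([](llvm::StoreInst *) { return ValueKindTy::Store; })
      .Default([](llvm::Value *) { return ValueKindTy::None; });
}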
static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy,
@@ -3787,7 +3933,7 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
// This calculates the size to transfer based on bounds and the underlying
// element type, provided bounds have been specified (Fortran
// pointers/allocatables/target and arrays that have sections specified fall
- // into this as well).
+ // into this as well)
if (!memberClause.getBounds().empty()) {
llvm::Value *elementCount = builder.getInt64(1);
for (auto bounds : memberClause.getBounds()) {
@@ -3835,6 +3981,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
return (mlirFlags & flag) == flag;
};
+ const bool hasExplicitMap =
+ (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
+ omp::ClauseMapFlags::none;
llvm::omp::OpenMPOffloadMappingFlags mapType =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
@@ -3875,6 +4024,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
if (mapTypeToBool(omp::ClauseMapFlags::attach))
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
+ if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
+ if (!hasExplicitMap)
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+ }
+
return mapType;
}
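For illustration, the "any bit other than the device-pointer marker" test above reduces to a simple mask check; a self-contained sketch with placeholder flag values:

#include <cstdint>

// Placeholder bit values; the real omp::ClauseMapFlags values differ.
constexpr uint64_t kMapTo = 1u << 0;
constexpr uint64_t kMapFrom = 1u << 1;
constexpr uint64_t kIsDevicePtr = 1u << 2;

static bool hasExplicitMap(uint64_t flags) {
  // True when any bit other than the is_device_ptr marker is set.
  return (flags & ~kIsDevicePtr) != 0;
}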
@@ -3910,10 +4065,12 @@ static void collectMapDataFromMapOperands(
mapData.Pointers.push_back(mapData.OriginalValue.back());
if (llvm::Value *refPtr =
- getRefPtrIfDeclareTarget(offloadPtr,
- moduleTranslation)) { // declare target
+ getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) {
mapData.IsDeclareTarget.push_back(true);
mapData.BasePointers.push_back(refPtr);
+ } else if (isDeclareTargetTo(offloadPtr)) {
+ mapData.IsDeclareTarget.push_back(true);
+ mapData.BasePointers.push_back(mapData.OriginalValue.back());
} else { // regular mapped variable
mapData.IsDeclareTarget.push_back(false);
mapData.BasePointers.push_back(mapData.OriginalValue.back());
@@ -3996,6 +4153,9 @@ static void collectMapDataFromMapOperands(
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
auto mapType = convertClauseMapFlags(mapOp.getMapType());
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+ bool isDevicePtr =
+ (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
+ omp::ClauseMapFlags::none;
mapData.OriginalValue.push_back(origValue);
mapData.BasePointers.push_back(origValue);
@@ -4022,14 +4182,18 @@ static void collectMapDataFromMapOperands(
mapData.Mappers.push_back(nullptr);
}
} else {
+ // For is_device_ptr we need the map type to propagate so the runtime
+ // can materialize the device-side copy of the pointer container.
mapData.Types.push_back(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
+ isDevicePtr ? mapType
+ : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
mapData.Mappers.push_back(nullptr);
}
mapData.Names.push_back(LLVM::createMappingInformation(
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
mapData.DevicePointers.push_back(
- llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
+ isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
+ : llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
mapData.IsAMapping.push_back(false);
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
}
@@ -4042,41 +4206,66 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
return std::distance(mapData.MapClause.begin(), res);
}
+static void sortMapIndices(llvm::SmallVectorImpl<size_t> &indices,
+ omp::MapInfoOp mapInfo) {
+ ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
+ llvm::SmallVector<size_t> occludedChildren;
+ llvm::sort(
+ indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
+ // Bail early if we are asked to look at the same index. If we do not
+ // bail early, we can end up mistakenly adding indices to
+ // occludedChildren. This can occur with some types of libc++ hardening.
+ if (a == b)
+ return false;
+
+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
+
+ for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
+ int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt();
+ int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt();
+
+ if (aIndex == bIndex)
+ continue;
+
+ if (aIndex < bIndex)
+ return true;
+
+ if (aIndex > bIndex)
+ return false;
+ }
+
+        // We iterated up to the end of the shorter member index list and the
+        // entries were equal up to that point, so select the member with the
+        // lower index count, i.e. the "parent".
+ bool memberAParent = memberIndicesA.size() < memberIndicesB.size();
+ if (memberAParent)
+ occludedChildren.push_back(b);
+ else
+ occludedChildren.push_back(a);
+ return memberAParent;
+ });
+
+  // We remove children from the index list that are overshadowed by
+  // a parent; this prevents us from retrieving them as the first or last
+  // element when the parent is the correct element in these cases.
+ for (auto v : occludedChildren)
+ indices.erase(std::remove(indices.begin(), indices.end(), v),
+ indices.end());
+}
+
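Stripped of the occlusion bookkeeping, the comparator above is a lexicographic order on member index paths in which a parent (a shorter prefix) sorts before its children; a minimal sketch:

#include <algorithm>
#include <cstdint>
#include <vector>

static bool pathLess(const std::vector<int64_t> &a,
                     const std::vector<int64_t> &b) {
  // When one path is a strict prefix of the other, the shorter path (the
  // parent) compares less, which matches the "memberAParent" case above.
  return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
}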
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
bool first) {
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
// Only 1 member has been mapped, we can return it.
if (indexAttr.size() == 1)
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
-
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
-
- llvm::sort(indices, [&](const size_t a, const size_t b) {
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
- for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
- int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
- int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
-
- if (aIndex == bIndex)
- continue;
-
- if (aIndex < bIndex)
- return first;
-
- if (aIndex > bIndex)
- return !first;
- }
-
- // Iterated the up until the end of the smallest member and
- // they were found to be equal up to that point, so select
- // the member with the lowest index count, so the "parent"
- return memberIndicesA.size() < memberIndicesB.size();
- });
-
+ sortMapIndices(indices, mapInfo);
return llvm::cast<omp::MapInfoOp>(
- mapInfo.getMembers()[indices.front()].getDefiningOp());
+ mapInfo.getMembers()[first ? indices.front() : indices.back()]
+ .getDefiningOp());
}
/// This function calculates the array/pointer offset for map data provided
@@ -4155,6 +4344,86 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
return idx;
}
+static void getAsIntegers(ArrayAttr values, llvm::SmallVector<int64_t> &ints) {
+ llvm::transform(values, std::back_inserter(ints), [](Attribute value) {
+ return cast<IntegerAttr>(value).getInt();
+ });
+}
+
+// Gathers members that are overlapping in the parent, excluding members that
+// themselves overlap, keeping the top-most (closest to the parent's level) map.
+static void
+getOverlappedMembers(llvm::SmallVectorImpl<size_t> &overlapMapDataIdxs,
+ omp::MapInfoOp parentOp) {
+ // No members mapped, no overlaps.
+ if (parentOp.getMembers().empty())
+ return;
+
+ // Single member, we can insert and return early.
+ if (parentOp.getMembers().size() == 1) {
+ overlapMapDataIdxs.push_back(0);
+ return;
+ }
+
+ // 1) collect list of top-level overlapping members from MemberOp
+ llvm::SmallVector<std::pair<int, ArrayAttr>> memberByIndex;
+ ArrayAttr indexAttr = parentOp.getMembersIndexAttr();
+ for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr))
+ memberByIndex.push_back(
+ std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr)));
+
+  // Sort the smallest first (higher up the parent -> member chain), so that
+  // when we remove members, we remove as much as we can in the initial
+  // iterations, reducing the number of passes required.
+ llvm::sort(memberByIndex.begin(), memberByIndex.end(),
+ [&](auto a, auto b) { return a.second.size() < b.second.size(); });
+
+  // Remove elements from the vector if there is a parent element that
+  // supersedes them, i.e. if member [0] is mapped, we can remove members
+  // [0,1], [0,2], etc.
+ llvm::SmallVector<std::pair<int, ArrayAttr>> skipList;
+ for (auto v : memberByIndex) {
+ llvm::SmallVector<int64_t> vArr(v.second.size());
+ getAsIntegers(v.second, vArr);
+ skipList.push_back(
+ *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) {
+ if (v == x)
+ return false;
+ llvm::SmallVector<int64_t> xArr(x.second.size());
+ getAsIntegers(x.second, xArr);
+ return std::equal(vArr.begin(), vArr.end(), xArr.begin()) &&
+ xArr.size() >= vArr.size();
+ }));
+ }
+
+ // Collect the indices, as we need the base pointer etc. from the MapData
+ // structure which is primarily accessible via index at the moment.
+ for (auto v : memberByIndex)
+ if (find(skipList.begin(), skipList.end(), v) == skipList.end())
+ overlapMapDataIdxs.push_back(v.first);
+}
+
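A standalone sketch of the pruning performed above, assuming each member is described by its index path: keep only members with no mapped ancestor (shorter prefix) in the set.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<size_t>
topMostMembers(const std::vector<std::vector<int64_t>> &paths) {
  std::vector<size_t> keep;
  for (size_t i = 0; i < paths.size(); ++i) {
    bool occluded = false;
    for (size_t j = 0; j < paths.size(); ++j) {
      if (i == j || paths[j].size() >= paths[i].size())
        continue;
      // paths[j] is shorter; if it is a prefix of paths[i], then member j is
      // a parent map that supersedes member i.
      if (std::equal(paths[j].begin(), paths[j].end(), paths[i].begin()))
        occluded = true;
    }
    if (!occluded)
      keep.push_back(i);
  }
  return keep;
}

With paths {[0], [0,1], [1,0]} this keeps indices 0 and 2, matching the [0] / [0,1] example in the comment above.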
+// The intent is to verify if the mapped data being passed is a
+// pointer -> pointee that requires special handling in certain cases,
+// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
+//
+// There may be a better way to verify this, but unfortunately with
+// opaque pointers we lose the ability to easily check if something is
+// a pointer whilst maintaining access to the underlying type.
+static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
+ // If we have a varPtrPtr field assigned then the underlying type is a pointer
+ if (mapOp.getVarPtrPtr())
+ return true;
+
+ // If the map data is declare target with a link clause, then it's represented
+ // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
+ // no relation to pointers.
+ if (isDeclareTargetLink(mapOp.getVarPtr()))
+ return true;
+
+ return false;
+}
+
// This creates two insertions into the MapInfosTy data structure for the
// "parent" of a set of members, (usually a container e.g.
// class/structure/derived type) when subsequent members have also been
@@ -4173,7 +4442,8 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
- MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
+ MapInfoData &mapData, uint64_t mapDataIndex,
+ TargetDirectiveEnumTy targetDirective) {
assert(!ompBuilder.Config.isTargetDevice() &&
"function only supported for host device codegen");
@@ -4182,7 +4452,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// base entry so the mapper receives correct copy semantics via its 'type'
// parameter. Also keep TARGET_PARAM when required for kernel arguments.
llvm::omp::OpenMPOffloadMappingFlags baseFlag =
- isTargetParams
+ (targetDirective == TargetDirectiveEnumTy::Target &&
+ !mapData.IsDeclareTarget[mapDataIndex])
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
@@ -4217,7 +4488,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// runtime information on the dynamically allocated data).
auto parentClause =
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
-
llvm::Value *lowAddr, *highAddr;
if (!parentClause.getPartialMap()) {
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
@@ -4263,39 +4533,85 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// further case specific flag modifications). For the moment, it handles
// what we support as expected.
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
+ bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) &
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) ==
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
- combinedInfo.Types.emplace_back(mapFlag);
- combinedInfo.DevicePointers.emplace_back(
- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
- combinedInfo.Mappers.emplace_back(nullptr);
- combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
- mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
- combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
- combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
- combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
- }
- return memberOfFlag;
-}
-
-// The intent is to verify if the mapped data being passed is a
-// pointer -> pointee that requires special handling in certain cases,
-// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
-//
-// There may be a better way to verify this, but unfortunately with
-// opaque pointers we lose the ability to easily check if something is
-// a pointer whilst maintaining access to the underlying type.
-static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
- // If we have a varPtrPtr field assigned then the underlying type is a pointer
- if (mapOp.getVarPtrPtr())
- return true;
- // If the map data is declare target with a link clause, then it's represented
- // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
- // no relation to pointers.
- if (isDeclareTargetLink(mapOp.getVarPtr()))
- return true;
+ if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) {
+ combinedInfo.Types.emplace_back(mapFlag);
+ combinedInfo.DevicePointers.emplace_back(
+ mapData.DevicePointers[mapDataIndex]);
+ combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
+ mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
+ combinedInfo.BasePointers.emplace_back(
+ mapData.BasePointers[mapDataIndex]);
+ combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
+ combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
+ combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
+ } else {
+ llvm::SmallVector<size_t> overlapIdxs;
+      // Find all of the members that "overlap", i.e. occlude other members
+      // that were mapped alongside the parent, e.g. member [0] occludes [0,1]
+      // and [0,2], but not [1,0].
+ getOverlappedMembers(overlapIdxs, parentClause);
+ // We need to make sure the overlapped members are sorted in order of
+ // lowest address to highest address.
+ sortMapIndices(overlapIdxs, parentClause);
+
+ lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
+ builder.getPtrTy());
+ highAddr = builder.CreatePointerCast(
+ builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
+ mapData.Pointers[mapDataIndex], 1),
+ builder.getPtrTy());
+
+      // TODO: We may want to skip arrays/array sections in this as Clang does.
+      // It appears to be an optimisation rather than a necessity, but this
+      // requires further investigation. However, we would have to make sure
+      // not to exclude maps with bounds that ARE pointers, as these are
+      // processed as separate components, i.e. pointer + data.
+ for (auto v : overlapIdxs) {
+ auto mapDataOverlapIdx = getMapDataMemberIdx(
+ mapData,
+ cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp()));
+ combinedInfo.Types.emplace_back(mapFlag);
+ combinedInfo.DevicePointers.emplace_back(
+ mapData.DevicePointers[mapDataOverlapIdx]);
+ combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
+ mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
+ combinedInfo.BasePointers.emplace_back(
+ mapData.BasePointers[mapDataIndex]);
+ combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
+ combinedInfo.Pointers.emplace_back(lowAddr);
+ combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
+ builder.CreatePtrDiff(builder.getInt8Ty(),
+ mapData.OriginalValue[mapDataOverlapIdx],
+ lowAddr),
+ builder.getInt64Ty(), /*isSigned=*/true));
+ lowAddr = builder.CreateConstGEP1_32(
+ checkIfPointerMap(llvm::cast<omp::MapInfoOp>(
+ mapData.MapClause[mapDataOverlapIdx]))
+ ? builder.getPtrTy()
+ : mapData.BaseType[mapDataOverlapIdx],
+ mapData.BasePointers[mapDataOverlapIdx], 1);
+ }
- return false;
+ combinedInfo.Types.emplace_back(mapFlag);
+ combinedInfo.DevicePointers.emplace_back(
+ mapData.DevicePointers[mapDataIndex]);
+ combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
+ mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
+ combinedInfo.BasePointers.emplace_back(
+ mapData.BasePointers[mapDataIndex]);
+ combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]);
+ combinedInfo.Pointers.emplace_back(lowAddr);
+ combinedInfo.Sizes.emplace_back(builder.CreateIntCast(
+ builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
+ builder.getInt64Ty(), true));
+ }
+ }
+ return memberOfFlag;
}
// This function is intended to add explicit mappings of members
@@ -4303,7 +4619,8 @@ static void processMapMembersWithParent(
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
MapInfoData &mapData, uint64_t mapDataIndex,
- llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
+ llvm::omp::OpenMPOffloadMappingFlags memberOfFlag,
+ TargetDirectiveEnumTy targetDirective) {
assert(!ompBuilder.Config.isTargetDevice() &&
"function only supported for host device codegen");
@@ -4348,8 +4665,15 @@ static void processMapMembersWithParent(
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
- if (checkIfPointerMap(memberClause))
+ bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr()
+ ? parentClause.getVarPtr()
+ : parentClause.getVarPtrPtr());
+ if (checkIfPointerMap(memberClause) &&
+ (!isDeclTargetTo ||
+ (targetDirective != TargetDirectiveEnumTy::TargetUpdate &&
+ targetDirective != TargetDirectiveEnumTy::TargetData))) {
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
+ }
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
@@ -4375,7 +4699,8 @@ static void processMapMembersWithParent(
}
static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
- MapInfosTy &combinedInfo, bool isTargetParams,
+ MapInfosTy &combinedInfo,
+ TargetDirectiveEnumTy targetDirective,
int mapDataParentIdx = -1) {
// Declare Target Mappings are excluded from being marked as
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
@@ -4387,7 +4712,8 @@ static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
if (isPtrTy)
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
- if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
+ if (targetDirective == TargetDirectiveEnumTy::Target &&
+ !mapData.IsDeclareTarget[mapDataIdx])
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
@@ -4416,7 +4742,7 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation,
llvm::OpenMPIRBuilder &ompBuilder,
DataLayout &dl, MapInfosTy &combinedInfo,
MapInfoData &mapData, uint64_t mapDataIndex,
- bool isTargetParams) {
+ TargetDirectiveEnumTy targetDirective) {
assert(!ompBuilder.Config.isTargetDevice() &&
"function only supported for host device codegen");
@@ -4440,17 +4766,18 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation,
// Clang maps array without bounds as pointers (which we do not
// currently do), whereas we treat them as arrays in all cases
// currently.
- processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
+ processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective,
mapDataIndex);
return;
}
llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
- combinedInfo, mapData, mapDataIndex, isTargetParams);
+ combinedInfo, mapData, mapDataIndex,
+ targetDirective);
processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
combinedInfo, mapData, mapDataIndex,
- memberOfParentFlag);
+ memberOfParentFlag, targetDirective);
}
// This is a variation on Clang's GenerateOpenMPCapturedVars, which
@@ -4528,10 +4855,10 @@ createAlteredByCaptureMap(MapInfoData &mapData,
static void genMapInfos(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
DataLayout &dl, MapInfosTy &combinedInfo,
- MapInfoData &mapData, bool isTargetParams = false) {
+ MapInfoData &mapData,
+ TargetDirectiveEnumTy targetDirective) {
assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
"function only supported for host device codegen");
-
// We wish to modify some of the methods in which arguments are
// passed based on their capture type by the target region, this can
// involve generating new loads and stores, which changes the
@@ -4561,22 +4888,24 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
if (!mapInfoOp.getMembers().empty()) {
processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
- combinedInfo, mapData, i, isTargetParams);
+ combinedInfo, mapData, i, targetDirective);
continue;
}
- processIndividualMap(mapData, i, combinedInfo, isTargetParams);
+ processIndividualMap(mapData, i, combinedInfo, targetDirective);
}
}
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- llvm::StringRef mapperFuncName);
+ llvm::StringRef mapperFuncName,
+ TargetDirectiveEnumTy targetDirective);
static llvm::Expected<llvm::Function *>
getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation) {
+ LLVM::ModuleTranslation &moduleTranslation,
+ TargetDirectiveEnumTy targetDirective) {
assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
"function only supported for host device codegen");
auto declMapperOp = cast<omp::DeclareMapperOp>(op);
@@ -4588,13 +4917,14 @@ getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
return lookupFunc;
return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
- mapperFuncName);
+ mapperFuncName, targetDirective);
}
static llvm::Expected<llvm::Function *>
emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
- llvm::StringRef mapperFuncName) {
+ llvm::StringRef mapperFuncName,
+ TargetDirectiveEnumTy targetDirective) {
assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
"function only supported for host device codegen");
auto declMapperOp = cast<omp::DeclareMapperOp>(op);
@@ -4622,10 +4952,11 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
MapInfoData mapData;
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
builder);
- genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
+ genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData,
+ targetDirective);
- // Drop the mapping that is no longer necessary so that the same region can
- // be processed multiple times.
+ // Drop the mapping that is no longer necessary so that the same region
+ // can be processed multiple times.
moduleTranslation.forgetMapping(declMapperOp.getRegion());
return combinedInfo;
};
@@ -4634,7 +4965,7 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
if (!combinedInfo.Mappers[i])
return nullptr;
return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
- moduleTranslation);
+ moduleTranslation, targetDirective);
};
llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
@@ -4655,10 +4986,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
SmallVector<Value> useDeviceAddrVars;
llvm::omp::RuntimeFunction RTLFn;
DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
+ TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
- /*SeparateBeginEndCalls=*/true);
+ llvm::OpenMPIRBuilder::TargetDataInfo info(
+ /*RequiresDevicePointerInfo=*/true,
+ /*SeparateBeginEndCalls=*/true);
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
bool isOffloadEntry =
isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
@@ -4757,7 +5090,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
MapInfosTy combinedInfo;
auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
builder.restoreIP(codeGenIP);
- genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
+ genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
+ targetDirective);
return combinedInfo;
};
@@ -4873,7 +5207,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
return nullptr;
info.HasMapper = true;
return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder,
- moduleTranslation);
+ moduleTranslation, targetDirective);
};
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
@@ -4980,15 +5314,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
// TODO: Add support for clauses which are valid for DISTRIBUTE
// constructs. Static schedule is the default.
- auto schedule = omp::ClauseScheduleKind::Static;
- bool isOrdered = false;
+ bool hasDistSchedule = distributeOp.getDistScheduleStatic();
+ auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
+ : omp::ClauseScheduleKind::Static;
+      // dist_schedule clauses are ordered; otherwise this should be false.
+ bool isOrdered = hasDistSchedule;
std::optional<omp::ScheduleModifier> scheduleMod;
bool isSimd = false;
llvm::omp::WorksharingLoopType workshareLoopType =
llvm::omp::WorksharingLoopType::DistributeStaticLoop;
bool loopNeedsBarrier = false;
- llvm::Value *chunk = nullptr;
-
+ llvm::Value *chunk = moduleTranslation.lookupValue(
+ distributeOp.getDistScheduleChunkSize());
llvm::CanonicalLoopInfo *loopInfo =
findCurrentLoopInfo(moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
@@ -4997,12 +5334,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
convertToScheduleKind(schedule), chunk, isSimd,
scheduleMod == omp::ScheduleModifier::monotonic,
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
- workshareLoopType);
+ workshareLoopType, false, hasDistSchedule, chunk);
if (!wsloopIP)
return wsloopIP.takeError();
}
-
if (failed(cleanupPrivateVars(builder, moduleTranslation,
distributeOp.getLoc(), privVarsInfo.llvmVars,
privVarsInfo.privatizers)))
@@ -5135,11 +5471,16 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
for (llvm::User *user : userVec) {
if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
if (insn->getFunction() == func) {
- builder.SetCurrentDebugLocation(insn->getDebugLoc());
- auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
- mapData.BasePointers[i]);
- load->moveBefore(insn->getIterator());
- user->replaceUsesOfWith(mapData.OriginalValue[i], load);
+ auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
+ llvm::Value *substitute = mapData.BasePointers[i];
+ if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr()
+ : mapOp.getVarPtr())) {
+ builder.SetCurrentDebugLocation(insn->getDebugLoc());
+ substitute = builder.CreateLoad(
+ mapData.BasePointers[i]->getType(), mapData.BasePointers[i]);
+ cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator());
+ }
+ user->replaceUsesOfWith(mapData.OriginalValue[i], substitute);
}
}
}
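A reduced sketch of the use-rewriting pattern above; the users are snapshotted first because replaceUsesOfWith mutates the use list being walked:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"

static void replaceUsesInFunction(llvm::Value *oldV, llvm::Value *newV,
                                  llvm::Function *func) {
  llvm::SmallVector<llvm::User *> users(oldV->user_begin(), oldV->user_end());
  for (llvm::User *user : users)
    if (auto *insn = llvm::dyn_cast<llvm::Instruction>(user))
      if (insn->getFunction() == func)
        user->replaceUsesOfWith(oldV, newV);
}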
@@ -5431,8 +5772,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
int32_t minTeamsVal = 1, maxTeamsVal = -1;
if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
- // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
- // clang and set min and max to the same value.
+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
+ // match clang and set min and max to the same value.
if (numTeamsUpper) {
if (auto val = extractConstInteger(numTeamsUpper))
minTeamsVal = maxTeamsVal = *val;
@@ -5624,9 +5965,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
auto &targetRegion = targetOp.getRegion();
- // Holds the private vars that have been mapped along with the block argument
- // that corresponds to the MapInfoOp corresponding to the private var in
- // question. So, for instance:
+  // Holds the private vars that have been mapped, along with the block
+  // argument that corresponds to the MapInfoOp of the private var in
+  // question. So, for instance:
//
// %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
// omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
@@ -5641,6 +5982,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
llvm::Function *llvmOutlinedFn = nullptr;
+ TargetDirectiveEnumTy targetDirective =
+ getTargetDirectiveEnumTyFromOp(&opInst);
// TODO: It can also be false if a compile-time constant `false` IF clause is
// specified.
@@ -5802,7 +6145,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto genMapInfoCB =
[&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
builder.restoreIP(codeGenIP);
- genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true);
+ genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData,
+ targetDirective);
return combinedInfos;
};
@@ -5882,7 +6226,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
return nullptr;
info.HasMapper = true;
return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder,
- moduleTranslation);
+ moduleTranslation, targetDirective);
};
llvm::Value *ifCond = nullptr;
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d9891e3..d7d215b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -34,12 +34,14 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugProgramInstruction.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/ModRef.h"
#include <optional>
@@ -522,6 +524,11 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
debugIntrinsics.insert(intrinsic);
}
+void ModuleImport::addDebugRecord(llvm::DbgVariableRecord *dbgRecord) {
+ if (!dbgRecords.contains(dbgRecord))
+ dbgRecords.insert(dbgRecord);
+}
+
static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
llvm::MDTuple *mdTuple) {
auto getLLVMFunction =
@@ -1214,7 +1221,7 @@ static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
llvm::Constant *constScalar) {
MLIRContext *context = builder.getContext();
- // Convert scalar intergers.
+ // Convert scalar integers.
if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
return builder.getIntegerAttr(
IntegerType::get(context, constInt->getBitWidth()),
@@ -2003,9 +2010,15 @@ FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) {
return floatAttr;
}
-DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) {
- auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
- auto *node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata());
+DILocalVariableAttr ModuleImport::matchLocalVariableAttr(
+ llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> valOrVariable) {
+ llvm::DILocalVariable *node = nullptr;
+ if (auto *value = dyn_cast<llvm::Value *>(valOrVariable)) {
+ auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
+ node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata());
+ } else {
+ node = cast<llvm::DILocalVariable *>(valOrVariable);
+ }
return debugImporter->translate(node);
}
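The llvm::PointerUnion dispatch above in isolation; the member pointer types here are arbitrary placeholders:

#include "llvm/ADT/PointerUnion.h"

static unsigned pointeeSize(llvm::PointerUnion<int *, double *> u) {
  if (llvm::isa<int *>(u))
    return sizeof(int);
  return sizeof(double);
}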
@@ -2544,6 +2557,41 @@ LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) {
if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(inst))
return convertIntrinsic(intrinsic);
+ // Process debug records attached to this instruction. Debug variable records
+ // are stored for later processing after all SSA values are converted, while
+ // debug label records can be converted immediately.
+ if (inst->DebugMarker) {
+ for (llvm::DbgRecord &dbgRecord : inst->DebugMarker->getDbgRecordRange()) {
+ // Store debug variable records for later processing.
+ if (auto *dbgVariableRecord =
+ dyn_cast<llvm::DbgVariableRecord>(&dbgRecord)) {
+ addDebugRecord(dbgVariableRecord);
+ continue;
+ }
+ Location loc = translateLoc(dbgRecord.getDebugLoc());
+ auto emitUnsupportedWarning = [&]() -> LogicalResult {
+ if (!emitExpensiveWarnings)
+ return success();
+ std::string options;
+ llvm::raw_string_ostream optionsStream(options);
+ dbgRecord.print(optionsStream);
+ emitWarning(loc) << "unhandled debug record " << optionsStream.str();
+ return success();
+ };
+ // Convert the debug label records in-place.
+ if (auto *dbgLabelRecord = dyn_cast<llvm::DbgLabelRecord>(&dbgRecord)) {
+ DILabelAttr labelAttr =
+ debugImporter->translate(dbgLabelRecord->getLabel());
+ if (!labelAttr)
+ return emitUnsupportedWarning();
+ LLVM::DbgLabelOp::create(builder, loc, labelAttr);
+ continue;
+ }
+ // Warn if an unsupported debug record is encountered.
+ return emitUnsupportedWarning();
+ }
+ }
+
// Convert all remaining LLVM instructions to MLIR operations.
return convertInstruction(inst);
}
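A minimal sketch of the record-walking loop above, which only tallies the two record kinds instead of converting them:

#include "llvm/IR/DebugProgramInstruction.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/Casting.h"

static void countDbgRecords(llvm::Instruction *inst, unsigned &numVariables,
                            unsigned &numLabels) {
  if (!inst->DebugMarker)
    return;
  for (llvm::DbgRecord &rec : inst->DebugMarker->getDbgRecordRange()) {
    if (llvm::isa<llvm::DbgVariableRecord>(&rec))
      ++numVariables;
    else if (llvm::isa<llvm::DbgLabelRecord>(&rec))
      ++numLabels;
  }
}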
@@ -2579,8 +2627,15 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
auto inaccessibleMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
- auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem,
- inaccessibleMem);
+ auto errnoMem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::ErrnoMem));
+ auto targetMem0 = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem0));
+ auto targetMem1 = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem1));
+ auto memAttr =
+ MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem,
+ inaccessibleMem, errnoMem, targetMem0, targetMem1);
// Only set the attr when it does not match the default value.
if (memAttr.isReadWrite())
return;
@@ -2885,8 +2940,15 @@ LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst,
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
- auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem,
- inaccessibleMem);
+ ModRefInfo errnoMem = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::ErrnoMem));
+ ModRefInfo targetMem0 = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem0));
+ ModRefInfo targetMem1 = convertModRefInfoFromLLVM(
+ memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem1));
+ auto memAttr =
+ MemoryEffectsAttr::get(op.getContext(), othermem, argMem, inaccessibleMem,
+ errnoMem, targetMem0, targetMem1);
// Only set the attribute when it does not match the default value.
if (!memAttr.isReadWrite())
op.setMemoryEffectsAttr(memAttr);
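
The same six-location constructor is now used for both the function and the call-site import. A minimal sketch of a direct use, with placeholder ModRefInfo values and an assumed `MLIRContext *ctx` (not taken from this patch):

    // Hypothetical effects: reads argument memory only; every other location,
    // including errno and the two target-specific locations, is untouched.
    auto memAttr = MemoryEffectsAttr::get(
        ctx, /*other=*/ModRefInfo::NoModRef, /*argMem=*/ModRefInfo::Ref,
        /*inaccessibleMem=*/ModRefInfo::NoModRef,
        /*errnoMem=*/ModRefInfo::NoModRef,
        /*targetMem0=*/ModRefInfo::NoModRef,
        /*targetMem1=*/ModRefInfo::NoModRef);
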
@@ -3007,6 +3069,11 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
if (failed(processDebugIntrinsics()))
return failure();
+ // Process the debug records that require a delayed conversion after
+ // everything else has been converted.
+ if (failed(processDebugRecords()))
+ return failure();
+
return success();
}
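
The deferred-record plumbing used here (`addDebugRecord` above, `dbgRecords` and `processDebugRecords` below) lives in the ModuleImport header, which is not part of this excerpt. The declarations presumably look roughly like the following sketch, inferred from the call sites rather than copied from the patch:

    /// Debug variable records that are converted only after all SSA values,
    /// mirroring the existing handling of debug intrinsics.
    SmallVector<llvm::DbgVariableRecord *> dbgRecords;

    /// Remembers a record for the deferred conversion in processDebugRecords().
    void addDebugRecord(llvm::DbgVariableRecord *dbgRecord) {
      dbgRecords.push_back(dbgRecord);
    }
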
@@ -3022,61 +3089,32 @@ static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) {
return !isa<llvm::ValueAsMetadata>(nodeAsVal->getMetadata());
}
-LogicalResult
-ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
- DominanceInfo &domInfo) {
- Location loc = translateLoc(dbgIntr->getDebugLoc());
- auto emitUnsupportedWarning = [&]() {
- if (emitExpensiveWarnings)
- emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr);
- return success();
- };
- // Drop debug intrinsics with arg lists.
- // TODO: Support debug intrinsics that have arg lists.
- if (dbgIntr->hasArgList())
- return emitUnsupportedWarning();
- // Kill locations can have metadata nodes as location operand. This
- // cannot be converted to poison as the type cannot be reconstructed.
- // TODO: find a way to support this case.
- if (isMetadataKillLocation(dbgIntr))
- return emitUnsupportedWarning();
- // Drop debug intrinsics if the associated variable information cannot be
- // translated due to cyclic debug metadata.
- // TODO: Support cyclic debug metadata.
- DILocalVariableAttr localVariableAttr =
- matchLocalVariableAttr(dbgIntr->getArgOperand(1));
- if (!localVariableAttr)
- return emitUnsupportedWarning();
- FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0));
- if (failed(argOperand))
- return emitError(loc) << "failed to convert a debug intrinsic operand: "
- << diag(*dbgIntr);
-
- // Ensure that the debug intrinsic is inserted right after its operand is
- // defined. Otherwise, the operand might not necessarily dominate the
- // intrinsic. If the defining operation is a terminator, insert the intrinsic
- // into a dominated block.
- OpBuilder::InsertionGuard guard(builder);
- if (Operation *op = argOperand->getDefiningOp();
+/// Ensure that the debug intrinsic is inserted right after the operand
+/// definition. Otherwise, the operand might not necessarily dominate the
+/// intrinsic. If the defining operation is a terminator, insert the intrinsic
+/// into a dominated block.
+static LogicalResult setDebugIntrinsicBuilderInsertionPoint(
+ mlir::OpBuilder &builder, DominanceInfo &domInfo, Value argOperand) {
+ if (Operation *op = argOperand.getDefiningOp();
op && op->hasTrait<OpTrait::IsTerminator>()) {
// Find a dominated block that can hold the debug intrinsic.
auto dominatedBlocks = domInfo.getNode(op->getBlock())->children();
// If no block is dominated by the terminator, this intrinsic cannot be
// converted.
if (dominatedBlocks.empty())
- return emitUnsupportedWarning();
+ return failure();
// Set insertion point before the terminator, to avoid inserting something
// before landingpads.
Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock();
builder.setInsertionPoint(dominatedBlock->getTerminator());
} else {
- Value insertPt = *argOperand;
- if (auto blockArg = dyn_cast<BlockArgument>(*argOperand)) {
+ Value insertPt = argOperand;
+ if (auto blockArg = dyn_cast<BlockArgument>(argOperand)) {
// The value might be coming from a phi node and is now a block argument,
// which means the insertion point is set to the start of the block. If
// this block is a target destination of an invoke, the insertion point
// must happen after the landing pad operation.
- Block *insertionBlock = argOperand->getParentBlock();
+ Block *insertionBlock = argOperand.getParentBlock();
if (!insertionBlock->empty() &&
isa<LandingpadOp>(insertionBlock->front()))
insertPt = cast<LandingpadOp>(insertionBlock->front()).getRes();
@@ -3084,23 +3122,152 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
builder.setInsertionPointAfterValue(insertPt);
}
- auto locationExprAttr =
- debugImporter->translateExpression(dbgIntr->getExpression());
- Operation *op =
- llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr)
- .Case([&](llvm::DbgDeclareInst *) {
- return LLVM::DbgDeclareOp::create(
- builder, loc, *argOperand, localVariableAttr, locationExprAttr);
- })
- .Case([&](llvm::DbgValueInst *) {
- return LLVM::DbgValueOp::create(
- builder, loc, *argOperand, localVariableAttr, locationExprAttr);
- });
+ return success();
+}
+
+std::tuple<DILocalVariableAttr, DIExpressionAttr, Value>
+ModuleImport::processDebugOpArgumentsAndInsertionPt(
+ Location loc,
+ llvm::function_ref<FailureOr<Value>()> convertArgOperandToValue,
+ llvm::Value *address,
+ llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> variable,
+ llvm::DIExpression *expression, DominanceInfo &domInfo) {
+ // Drop debug intrinsics if the associated debug information cannot be
+ // translated due to an unsupported construct.
+ DILocalVariableAttr localVarAttr = matchLocalVariableAttr(variable);
+ if (!localVarAttr)
+ return {};
+ FailureOr<Value> argOperand = convertArgOperandToValue();
+ if (failed(argOperand)) {
+ emitError(loc) << "failed to convert a debug operand: " << diag(*address);
+ return {};
+ }
+
+ if (setDebugIntrinsicBuilderInsertionPoint(builder, domInfo, *argOperand)
+ .failed())
+ return {};
+
+ return {localVarAttr, debugImporter->translateExpression(expression),
+ *argOperand};
+}
+
+LogicalResult
+ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
+ DominanceInfo &domInfo) {
+ Location loc = translateLoc(dbgIntr->getDebugLoc());
+ auto emitUnsupportedWarning = [&]() {
+ if (emitExpensiveWarnings)
+ emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr);
+ return success();
+ };
+
+ OpBuilder::InsertionGuard guard(builder);
+ auto convertArgOperandToValue = [&]() {
+ return convertMetadataValue(dbgIntr->getArgOperand(0));
+ };
+
+ // Drop debug intrinsics with an argument list.
+ // TODO: Support this case.
+ if (dbgIntr->hasArgList())
+ return emitUnsupportedWarning();
+
+ // Drop debug intrinsics with kill locations that have metadata nodes as the
+ // location operand; these cannot be converted to poison because the type
+ // cannot be reconstructed.
+ // TODO: Support this case.
+ if (isMetadataKillLocation(dbgIntr))
+ return emitUnsupportedWarning();
+
+ auto [localVariableAttr, locationExprAttr, locVal] =
+ processDebugOpArgumentsAndInsertionPt(
+ loc, convertArgOperandToValue, dbgIntr->getArgOperand(0),
+ dbgIntr->getArgOperand(1), dbgIntr->getExpression(), domInfo);
+
+ if (!localVariableAttr)
+ return emitUnsupportedWarning();
+
+ if (!locVal) // locVal is expected to be valid whenever localVariableAttr is.
+ return failure();
+
+ Operation *op = nullptr;
+ if (isa<llvm::DbgDeclareInst>(dbgIntr))
+ op = LLVM::DbgDeclareOp::create(builder, loc, locVal, localVariableAttr,
+ locationExprAttr);
+ else if (isa<llvm::DbgValueInst>(dbgIntr))
+ op = LLVM::DbgValueOp::create(builder, loc, locVal, localVariableAttr,
+ locationExprAttr);
+ else
+ return emitUnsupportedWarning();
+
mapNoResultOp(dbgIntr, op);
setNonDebugMetadataAttrs(dbgIntr, op);
return success();
}
+LogicalResult
+ModuleImport::processDebugRecord(llvm::DbgVariableRecord &dbgRecord,
+ DominanceInfo &domInfo) {
+ OpBuilder::InsertionGuard guard(builder);
+ Location loc = translateLoc(dbgRecord.getDebugLoc());
+ auto emitUnsupportedWarning = [&]() -> LogicalResult {
+ if (!emitExpensiveWarnings)
+ return success();
+ std::string options;
+ llvm::raw_string_ostream optionsStream(options);
+ dbgRecord.print(optionsStream);
+ emitWarning(loc) << "unhandled debug variable record "
+ << optionsStream.str();
+ return success();
+ };
+
+ // Drop debug records with an argument list.
+ // TODO: Support this case.
+ if (dbgRecord.hasArgList())
+ return emitUnsupportedWarning();
+
+ // Drop all other debug records whose address operand cannot be converted
+ // to an SSA value, such as an empty metadata node.
+ // TODO: Support this case.
+ if (!dbgRecord.getAddress())
+ return emitUnsupportedWarning();
+
+ auto convertArgOperandToValue = [&]() -> FailureOr<Value> {
+ llvm::Value *value = dbgRecord.getAddress();
+
+ // Return the mapped value if it has been converted before.
+ auto it = valueMapping.find(value);
+ if (it != valueMapping.end())
+ return it->getSecond();
+
+ // Convert constants such as immediate values that have no mapping yet.
+ if (auto *constant = dyn_cast<llvm::Constant>(value))
+ return convertConstantExpr(constant);
+ return failure();
+ };
+
+ auto [localVariableAttr, locationExprAttr, locVal] =
+ processDebugOpArgumentsAndInsertionPt(
+ loc, convertArgOperandToValue, dbgRecord.getAddress(),
+ dbgRecord.getVariable(), dbgRecord.getExpression(), domInfo);
+
+ if (!localVariableAttr)
+ return emitUnsupportedWarning();
+
+ if (!locVal) // locVal is expected to be valid whenever localVariableAttr is.
+ return failure();
+
+ if (dbgRecord.isDbgDeclare())
+ LLVM::DbgDeclareOp::create(builder, loc, locVal, localVariableAttr,
+ locationExprAttr);
+ else if (dbgRecord.isDbgValue())
+ LLVM::DbgValueOp::create(builder, loc, locVal, localVariableAttr,
+ locationExprAttr);
+ else // isDbgAssign
+ return emitUnsupportedWarning();
+
+ return success();
+}
+
LogicalResult ModuleImport::processDebugIntrinsics() {
DominanceInfo domInfo;
for (llvm::Instruction *inst : debugIntrinsics) {
@@ -3111,6 +3278,15 @@ LogicalResult ModuleImport::processDebugIntrinsics() {
return success();
}
+LogicalResult ModuleImport::processDebugRecords() {
+ DominanceInfo domInfo;
+ for (llvm::DbgVariableRecord *dbgRecord : dbgRecords)
+ if (failed(processDebugRecord(*dbgRecord, domInfo)))
+ return failure();
+ dbgRecords.clear();
+ return success();
+}
+
LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
Block *block) {
builder.setInsertionPointToStart(block);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 64e3c5f..fad9bd6b7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -588,10 +588,17 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
- if (auto intAttr = dyn_cast<IntegerAttr>(attr))
- return llvm::ConstantInt::get(
- llvmType,
- intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ // If the attribute is an unsigned integer or a 1-bit integer, zero-extend
+ // the value to the bit width of the LLVM type. Otherwise, sign-extend.
+ auto intTy = dyn_cast<IntegerType>(intAttr.getType());
+ APInt value;
+ if (intTy && (intTy.isUnsigned() || intTy.getWidth() == 1))
+ value = intAttr.getValue().zextOrTrunc(llvmType->getIntegerBitWidth());
+ else
+ value = intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth());
+ return llvm::ConstantInt::get(llvmType, value);
+ }
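
A short worked example of why the split matters, using plain APInt arithmetic (not from the patch): for a `ui8` attribute holding 255 that is materialized as an `i32` constant, sign-extension would yield -1, and an `i1` "true" would likewise become -1, while zero-extension preserves 255 and 1.

    llvm::APInt u8Val(/*numBits=*/8, /*val=*/255);
    llvm::APInt wrong = u8Val.sextOrTrunc(32); // 0xFFFFFFFF, i.e. -1
    llvm::APInt right = u8Val.zextOrTrunc(32); // 0x000000FF, i.e. 255

    llvm::APInt boolTrue(/*numBits=*/1, /*val=*/1);
    llvm::APInt asSext = boolTrue.sextOrTrunc(32); // -1
    llvm::APInt asZext = boolTrue.zextOrTrunc(32); // 1
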
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
// Special case for 8-bit floats, which are represented by integers due to
@@ -677,10 +684,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
}
}
- // std::vector is used here to accomodate large number of elements that
- // exceed SmallVector capacity.
- std::vector<llvm::Constant *> constants(numElements, child);
- return llvm::ConstantArray::get(arrayType, constants);
+ // std::vector is used here to accommodate a large number of elements that
+ // exceed SmallVector capacity.
+ std::vector<llvm::Constant *> constants(numElements, child);
+ return llvm::ConstantArray::get(arrayType, constants);
}
}
@@ -892,10 +899,13 @@ void mlir::LLVM::detail::connectPHINodes(Region &region,
llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
- llvm::Module *module = builder.GetInsertBlock()->getModule();
- llvm::Function *fn =
- llvm::Intrinsic::getOrInsertDeclaration(module, intrinsic, tys);
- return builder.CreateCall(fn, args);
+ return builder.CreateIntrinsic(intrinsic, tys, args);
+}
+
+llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
+ llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
+ llvm::Type *retTy, ArrayRef<llvm::Value *> args) {
+ return builder.CreateIntrinsic(retTy, intrinsic, args);
}
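
A hedged usage sketch for the new return-type based overload; `llvm.ctpop` is only an illustrative overloaded intrinsic and the operand is a placeholder constant:

    llvm::Value *operand = builder.getInt32(42);
    llvm::CallInst *call = mlir::LLVM::detail::createIntrinsicCall(
        builder, llvm::Intrinsic::ctpop, /*retTy=*/builder.getInt32Ty(),
        /*args=*/{operand});
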
llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
@@ -1637,6 +1647,15 @@ static void convertFunctionMemoryAttributes(LLVMFuncOp func,
newMemEffects |=
llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
convertModRefInfoToLLVM(memEffects.getOther()));
+ newMemEffects |=
+ llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem,
+ convertModRefInfoToLLVM(memEffects.getErrnoMem()));
+ newMemEffects |=
+ llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem0,
+ convertModRefInfoToLLVM(memEffects.getTargetMem0()));
+ newMemEffects |=
+ llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1,
+ convertModRefInfoToLLVM(memEffects.getTargetMem1()));
llvmFunc->setMemoryEffects(newMemEffects);
}
@@ -2122,8 +2141,16 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
// LLVM metadata instances.
AttrTypeWalker walker;
walker.addWalk([&](TBAARootAttr root) {
- tbaaMetadataMapping.insert(
- {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))});
+ llvm::MDNode *node;
+ if (StringAttr id = root.getId()) {
+ node = llvm::MDNode::get(ctx, llvm::MDString::get(ctx, id));
+ } else {
+ // Anonymous root nodes are self-referencing.
+ auto selfRef = llvm::MDNode::getTemporary(ctx, {});
+ node = llvm::MDNode::get(ctx, {selfRef.get()});
+ node->replaceOperandWith(0, node);
+ }
+ tbaaMetadataMapping.insert({root, node});
});
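
For anonymous roots the walker builds a node whose only operand is the node itself. A standalone sketch of the same LLVM API pattern, with `llvmCtx` standing in for the translation's llvm::LLVMContext:

    // Create a temporary placeholder, wrap it in the real node, then retarget
    // the operand so the node refers to itself. The temporary is released when
    // `placeholder` goes out of scope and is no longer referenced.
    llvm::TempMDTuple placeholder = llvm::MDNode::getTemporary(llvmCtx, {});
    llvm::MDNode *anonRoot = llvm::MDNode::get(llvmCtx, {placeholder.get()});
    anonRoot->replaceOperandWith(0, anonRoot);
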
walker.addWalk([&](TBAATypeDescriptorAttr descriptor) {
@@ -2254,8 +2281,11 @@ llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
/* HasRequiresUnifiedSharedMemory = */ false,
/* HasRequiresDynamicAllocators = */ false);
unsigned int defaultAS =
- getLLVMModule()->getDataLayout().getProgramAddressSpace();
+ llvmModule->getDataLayout().getProgramAddressSpace();
config.setDefaultTargetAS(defaultAS);
+ config.setRuntimeCC(llvmModule->getTargetTriple().isSPIRV()
+ ? llvm::CallingConv::SPIR_FUNC
+ : llvm::CallingConv::C);
ompBuilder->setConfig(std::move(config));
ompBuilder->initialize();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c27f9aa..5b04a14 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
return processLoopMerge(operands);
case spirv::Opcode::OpPhi:
return processPhi(operands);
+ case spirv::Opcode::OpSwitch:
+ return processSwitch(operands);
case spirv::Opcode::OpUndef:
return processUndef(operands);
default:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 6492708..50883d9 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -346,6 +346,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::Constant:
case spirv::Decoration::Invariant:
case spirv::Decoration::Patch:
+ case spirv::Decoration::Coherent:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
@@ -2292,6 +2293,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) {
+ if (!curBlock)
+ return emitError(unknownLoc, "OpSwitch must appear in a block");
+
+ if (operands.size() < 2)
+ return emitError(unknownLoc, "OpSwitch must at least specify a selector "
+ "and a default target");
+
+ if (operands.size() % 2)
+ return emitError(unknownLoc,
+ "OpSwitch must have an even number of operands: a "
+ "selector, a default target, and any number of literal "
+ "and label <id> pairs");
+
+ Value selector = getValue(operands[0]);
+ Block *defaultBlock = getOrCreateBlock(operands[1]);
+ Location loc = createFileLineColLoc(opBuilder);
+
+ SmallVector<int32_t> literals;
+ SmallVector<Block *> blocks;
+ for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
+ literals.push_back(operands[i]);
+ blocks.push_back(getOrCreateBlock(operands[i + 1]));
+ }
+
+ SmallVector<ValueRange> targetOperands(blocks.size(), {});
+ spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
+ ArrayRef<Value>(), literals, blocks, targetOperands);
+
+ return success();
+}
+
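
For orientation, a small decoding sketch of the OpSwitch word layout handled above (all IDs are made up): the first two words are the selector and the default label, followed by literal / label pairs.

    // {12 /*selector*/, 20 /*default*/, 0, 21, 7, 22} maps literal 0 to block
    // <id> 21, literal 7 to block <id> 22, and everything else to block <id> 20.
    SmallVector<uint32_t> words = {12, 20, 0, 21, 7, 22};
    SmallVector<std::pair<int32_t, uint32_t>> cases;
    for (unsigned i = 2, e = words.size(); i < e; i += 2)
      cases.push_back({static_cast<int32_t>(words[i]), words[i + 1]});
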
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spirv.mlir.selection/spirv.mlir.loop op.
@@ -2799,6 +2832,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
branchCondOp.getFalseBlock());
branchCondOp.erase();
+ } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
+ if (target == switchOp.getDefaultTarget()) {
+ SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
+ DenseIntElementsAttr literals =
+ switchOp.getLiterals().value_or(DenseIntElementsAttr());
+ spirv::SwitchOp::create(
+ opBuilder, switchOp.getLoc(), switchOp.getSelector(),
+ switchOp.getDefaultTarget(), blockArgs, literals,
+ switchOp.getTargets(), targetOperands);
+ switchOp.erase();
+ } else {
+ SuccessorRange targets = switchOp.getTargets();
+ auto it = llvm::find(targets, target);
+ assert(it != targets.end());
+ size_t index = std::distance(targets.begin(), it);
+ switchOp.getTargetOperandsMutable(index).assign(blockArgs);
+ }
} else {
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
}
@@ -2819,7 +2869,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}
-LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+LogicalResult spirv::Deserializer::splitSelectionHeader() {
// Create a copy, so we can modify keys in the original.
BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
@@ -2836,7 +2886,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
Operation *terminator = block->getTerminator();
assert(terminator);
- if (!isa<spirv::BranchConditionalOp>(terminator))
+ if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
continue;
// Check if the current header block is a merge block of another construct.
@@ -2846,10 +2896,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
splitHeaderMergeBlock = true;
}
- // Do not split a block that only contains a conditional branch, unless it
- // is also a merge block of another construct - in that case we want to
- // split the block. We do not want two constructs to share header / merge
- // block.
+ // Do not split a block that only contains a conditional branch / switch,
+ // unless it is also a merge block of another construct - in that case we
+ // want to split the block. We do not want two constructs to share header /
+ // merge block.
if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
Block *newBlock = block->splitBlock(terminator);
OpBuilder builder(block, block->end());
@@ -2887,13 +2937,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
logger.startLine() << "\n";
});
- if (failed(splitConditionalBlocks())) {
+ if (failed(splitSelectionHeader())) {
return failure();
}
- // TODO: This loop is non-deterministic. Iteration order may vary between runs
- // for the same shader as the key to the map is a pointer. See:
- // https://github.com/llvm/llvm-project/issues/128547
while (!blockMergeInfo.empty()) {
Block *headerBlock = blockMergeInfo.begin()->first;
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 6027f1a..50c9350 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -58,7 +58,9 @@ struct DebugLine {
};
/// Map from a selection/loop's header block to its merge (and continue) target.
-using BlockMergeInfoMap = DenseMap<Block *, BlockMergeInfo>;
+/// Use `MapVector<>` to ensure a deterministic iteration order with a pointer
+/// key.
+using BlockMergeInfoMap = llvm::MapVector<Block *, BlockMergeInfo>;
/// A "deferred struct type" is a struct type with one or more member types not
/// known when the Deserializer first encounters the struct. This happens, for
@@ -278,11 +280,11 @@ private:
return opBuilder.getStringAttr(attrName);
}
- /// Move a conditional branch into a separate basic block to avoid unnecessary
- /// sinking of defs that may be required outside a selection region. This
- /// function also ensures that a single block cannot be a header block of one
- /// selection construct and the merge block of another.
- LogicalResult splitConditionalBlocks();
+ /// Move a conditional branch or a switch into a separate basic block to avoid
+ /// unnecessary sinking of defs that may be required outside a selection
+ /// region. This function also ensures that a single block cannot be a header
+ /// block of one selection construct and the merge block of another.
+ LogicalResult splitSelectionHeader();
//===--------------------------------------------------------------------===//
// Type
@@ -472,6 +474,9 @@ private:
/// Processes a SPIR-V OpPhi instruction with the given `operands`.
LogicalResult processPhi(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpSwitch instruction with the given `operands`.
+ LogicalResult processSwitch(ArrayRef<uint32_t> operands);
+
/// Creates block arguments on predecessors previously recorded when handling
/// OpPhi instructions.
LogicalResult wireUpBlockArgument();
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 85e92c7..6397d2c 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
return success();
}
+LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) {
+ uint32_t selectorID = getValueID(switchOp.getSelector());
+ uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget());
+ SmallVector<uint32_t> arguments{selectorID, defaultLabelID};
+
+ std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals();
+ BlockRange targets = switchOp.getTargets();
+ if (literals) {
+ for (auto [literal, target] : llvm::zip_equal(*literals, targets)) {
+ arguments.push_back(literal.getLimitedValue());
+ uint32_t targetLabelID = getOrCreateBlockID(target);
+ arguments.push_back(targetLabelID);
+ }
+ }
+
+ if (failed(emitDebugLine(functionBody, switchOp.getLoc())))
+ return failure();
+ encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments);
+ return success();
+}
+
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
auto varName = addressOfOp.getVariable();
auto variableID = getVariableID(varName);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 29ed5a4..c879a2b 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -373,6 +373,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::Block:
case spirv::Decoration::Invariant:
case spirv::Decoration::Patch:
+ case spirv::Decoration::Coherent:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -1443,7 +1444,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
assert(branchCondOp.getFalseTarget() == block);
blockOperands = branchCondOp.getFalseTargetOperands();
}
-
+ assert(!blockOperands->empty() &&
+ "expected non-empty block operand range");
+ predecessors.emplace_back(spirvPredecessor, *blockOperands);
+ } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
+ std::optional<OperandRange> blockOperands;
+ if (block == switchOp.getDefaultTarget()) {
+ blockOperands = switchOp.getDefaultOperands();
+ } else {
+ SuccessorRange targets = switchOp.getTargets();
+ auto it = llvm::find(targets, block);
+ assert(it != targets.end());
+ size_t index = std::distance(targets.begin(), it);
+ blockOperands = switchOp.getTargetOperands(index);
+ }
assert(!blockOperands->empty() &&
"expected non-empty block operand range");
predecessors.emplace_back(spirvPredecessor, *blockOperands);
@@ -1579,6 +1593,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
.Case([&](spirv::SpecConstantOperationOp op) {
return processSpecConstantOperationOp(op);
})
+ .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index add372b..6e79c13 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -304,6 +304,8 @@ private:
LogicalResult processBranchOp(spirv::BranchOp branchOp);
+ LogicalResult processSwitchOp(spirv::SwitchOp switchOp);
+
//===--------------------------------------------------------------------===//
// Operations
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp
index 6f2e4cd..e82807f 100644
--- a/mlir/lib/Tools/PDLL/AST/Context.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Context.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Context.h"
-#include "TypeDetail.h"
+#include "mlir/Tools/PDLL/AST/Types.h"
using namespace mlir;
using namespace mlir::pdll::ast;
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 5aa0937..4358ceb 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -21,7 +21,7 @@ static StringRef copyStringWithNull(Context &ctx, StringRef str) {
return str;
char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
- std::copy(str.begin(), str.end(), data);
+ llvm::copy(str, data);
data[str.size()] = 0;
return StringRef(data, str.size());
}
diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
deleted file mode 100644
index a0bd84e..0000000
--- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h
+++ /dev/null
@@ -1,141 +0,0 @@
-//===- TypeDetail.h ---------------------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
-#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
-
-#include "mlir/Tools/PDLL/AST/Types.h"
-
-namespace mlir {
-namespace pdll {
-namespace ast {
-//===----------------------------------------------------------------------===//
-// Type
-//===----------------------------------------------------------------------===//
-
-struct Type::Storage : public StorageUniquer::BaseStorage {
- Storage(TypeID typeID) : typeID(typeID) {}
-
- /// The type identifier for the derived type class.
- TypeID typeID;
-};
-
-namespace detail {
-
-/// A utility CRTP base class that defines many of the necessary utilities for
-/// defining a PDLL AST Type.
-template <typename ConcreteT, typename KeyT = void>
-struct TypeStorageBase : public Type::Storage {
- using KeyTy = KeyT;
- using Base = TypeStorageBase<ConcreteT, KeyT>;
- TypeStorageBase(KeyTy key)
- : Type::Storage(TypeID::get<ConcreteT>()), key(key) {}
-
- /// Construct an instance with the given storage allocator.
- static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
- const KeyTy &key) {
- return new (alloc.allocate<ConcreteT>()) ConcreteT(key);
- }
-
- /// Utility methods required by the storage allocator.
- bool operator==(const KeyTy &key) const { return this->key == key; }
-
- /// Return the key value of this storage class.
- const KeyTy &getValue() const { return key; }
-
-protected:
- KeyTy key;
-};
-/// A specialization of the storage base for singleton types.
-template <typename ConcreteT>
-struct TypeStorageBase<ConcreteT, void> : public Type::Storage {
- using Base = TypeStorageBase<ConcreteT, void>;
- TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {}
-};
-
-//===----------------------------------------------------------------------===//
-// AttributeType
-//===----------------------------------------------------------------------===//
-
-struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {};
-
-//===----------------------------------------------------------------------===//
-// ConstraintType
-//===----------------------------------------------------------------------===//
-
-struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {};
-
-//===----------------------------------------------------------------------===//
-// OperationType
-//===----------------------------------------------------------------------===//
-
-struct OperationTypeStorage
- : public TypeStorageBase<OperationTypeStorage,
- std::pair<StringRef, const ods::Operation *>> {
- using Base::Base;
-
- static OperationTypeStorage *
- construct(StorageUniquer::StorageAllocator &alloc,
- const std::pair<StringRef, const ods::Operation *> &key) {
- return new (alloc.allocate<OperationTypeStorage>()) OperationTypeStorage(
- std::make_pair(alloc.copyInto(key.first), key.second));
- }
-};
-
-//===----------------------------------------------------------------------===//
-// RangeType
-//===----------------------------------------------------------------------===//
-
-struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
- using Base::Base;
-};
-
-//===----------------------------------------------------------------------===//
-// RewriteType
-//===----------------------------------------------------------------------===//
-
-struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {};
-
-//===----------------------------------------------------------------------===//
-// TupleType
-//===----------------------------------------------------------------------===//
-
-struct TupleTypeStorage
- : public TypeStorageBase<TupleTypeStorage,
- std::pair<ArrayRef<Type>, ArrayRef<StringRef>>> {
- using Base::Base;
-
- static TupleTypeStorage *
- construct(StorageUniquer::StorageAllocator &alloc,
- std::pair<ArrayRef<Type>, ArrayRef<StringRef>> key) {
- SmallVector<StringRef> names = llvm::to_vector(llvm::map_range(
- key.second, [&](StringRef name) { return alloc.copyInto(name); }));
- return new (alloc.allocate<TupleTypeStorage>())
- TupleTypeStorage(std::make_pair(alloc.copyInto(key.first),
- alloc.copyInto(llvm::ArrayRef(names))));
- }
-};
-
-//===----------------------------------------------------------------------===//
-// TypeType
-//===----------------------------------------------------------------------===//
-
-struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {};
-
-//===----------------------------------------------------------------------===//
-// ValueType
-//===----------------------------------------------------------------------===//
-
-struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {};
-
-} // namespace detail
-} // namespace ast
-} // namespace pdll
-} // namespace mlir
-
-#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp
index 1468ac2..d5497b0 100644
--- a/mlir/lib/Tools/PDLL/AST/Types.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Types.cpp
@@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Types.h"
-#include "TypeDetail.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include <optional>
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 9ef405d..018a188 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -681,17 +681,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
return success();
}
-std::pair<std::string, std::string>
-mlir::registerAndParseCLIOptions(int argc, char **argv,
- llvm::StringRef toolName,
- DialectRegistry &registry) {
- static cl::opt<std::string> inputFilename(
- cl::Positional, cl::desc("<input file>"), cl::init("-"));
-
- static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
- cl::value_desc("filename"),
- cl::init("-"));
- // Register any command line options.
+std::string mlir::registerCLIOptions(llvm::StringRef toolName,
+ DialectRegistry &registry) {
MlirOptMainConfig::registerCLOptions(registry);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
@@ -706,11 +697,29 @@ mlir::registerAndParseCLIOptions(int argc, char **argv,
interleaveComma(registry.getDialectNames(), os,
[&](auto name) { os << name; });
}
- // Parse pass names in main to ensure static initialization completed.
+ return helpHeader;
+}
+
+std::pair<std::string, std::string>
+mlir::parseCLIOptions(int argc, char **argv, llvm::StringRef helpHeader) {
+ static cl::opt<std::string> inputFilename(
+ cl::Positional, cl::desc("<input file>"), cl::init("-"));
+
+ static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
+ cl::value_desc("filename"),
+ cl::init("-"));
cl::ParseCommandLineOptions(argc, argv, helpHeader);
return std::make_pair(inputFilename.getValue(), outputFilename.getValue());
}
+std::pair<std::string, std::string>
+mlir::registerAndParseCLIOptions(int argc, char **argv,
+ llvm::StringRef toolName,
+ DialectRegistry &registry) {
+ auto helpHeader = registerCLIOptions(toolName, registry);
+ return parseCLIOptions(argc, argv, helpHeader);
+}
+
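
A hedged sketch of how a downstream driver might use the split entry points: register the options (and build the help header) early, and defer parsing until all other static initialization has run. The tool name and the registered dialect are placeholders, and includes are omitted.

    int main(int argc, char **argv) {
      mlir::DialectRegistry registry;
      registry.insert<mlir::func::FuncDialect>(); // placeholder registration
      std::string helpHeader = mlir::registerCLIOptions("my-opt", registry);
      // ... register passes, load plugins, etc. ...
      auto [inputFilename, outputFilename] =
          mlir::parseCLIOptions(argc, argv, helpHeader);
      // ... hand the parsed filenames to MlirOptMain or a custom pipeline ...
      return 0;
    }
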
static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
llvm::outs() << "Available Dialects: ";
interleave(registry.getDialectNames(), llvm::outs(), ",");
diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
index 685e794..64e86f2 100644
--- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
+++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
@@ -153,5 +153,12 @@ int mlir::MlirTblgenMain(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv);
- return TableGenMain(argv[0], &mlirTableGenMain);
+ return TableGenMain(
+ argv[0], [](TableGenOutputFiles &OutFiles, const RecordKeeper &RK) {
+ std::string S;
+ raw_string_ostream OS(S);
+ bool Res = mlirTableGenMain(OS, RK);
+ OutFiles = {S, {}};
+ return Res;
+ });
}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 54b67f5..8907724 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_library(MLIRTransforms
DEPENDS
MLIRTransformsPassIncGen
+ MLIRTransformsDialectInterfaceIncGen
LINK_LIBS PUBLIC
MLIRAnalysis
@@ -39,4 +40,5 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
+ MLIRUBDialect
)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 41f3f9d..e9ced064 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
@@ -260,6 +261,22 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
+ // Operations that have dead operands can be erased regardless of their
+ // side effects. The liveness analysis would not have marked an SSA value as
+ // "dead" if it had a side-effecting user that is reachable.
+ bool hasDeadOperand =
+ markLives(op->getOperands(), nonLiveSet, la).flip().any();
+ if (hasDeadOperand) {
+ LDBG() << "Simple op has dead operands, so the op must be dead: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ assert(!hasLive(op->getResults(), nonLiveSet, la) &&
+ "expected the op to have no live results");
+ cl.operations.push_back(op);
+ collectNonLiveValues(nonLiveSet, op->getResults(),
+ BitVector(op->getNumResults(), true));
+ return;
+ }
+
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
@@ -361,6 +378,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// block other than the entry block, because every block has a terminator.
for (Block &block : funcOp.getBlocks()) {
Operation *returnOp = block.getTerminator();
+ if (!returnOp->hasTrait<OpTrait::ReturnLike>())
+ continue;
if (returnOp && returnOp->getNumOperands() == numReturns)
cl.operands.push_back({returnOp, nonLiveRets});
}
@@ -700,7 +719,11 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
}
/// Steps to process a `BranchOpInterface` operation:
-/// Iterate through each successor block of `branchOp`.
+///
+/// When a non-forwarded operand is dead (e.g., the condition value of a
+/// conditional branch op), the entire operation is dead.
+///
+/// Otherwise, iterate through each successor block of `branchOp`.
/// (1) For each successor block, gather all operands from all successors.
/// (2) Fetch their associated liveness analysis data and collect for future
/// removal.
@@ -711,7 +734,22 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
+
+ // Check for dead non-forwarded operands.
+ BitVector deadNonForwardedOperands =
+ markLives(branchOp->getOperands(), nonLiveSet, la).flip();
unsigned numSuccessors = branchOp->getNumSuccessors();
+ for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+ SuccessorOperands successorOperands =
+ branchOp.getSuccessorOperands(succIdx);
+ // Clear the bits of the forwarded operands so that only the dead
+ // non-forwarded operands remain set.
+ for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands())
+ deadNonForwardedOperands[opOperand.getOperandNumber()] = false;
+ }
+ if (deadNonForwardedOperands.any()) {
+ cl.operations.push_back(branchOp.getOperation());
+ return;
+ }
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
Block *successorBlock = branchOp->getSuccessor(succIdx);
@@ -742,23 +780,70 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
LDBG() << "Starting cleanup of dead values...";
- // 1. Operations
+ // 1. Blocks. We must remove the block arguments and successor operands
+ // before deleting the operations, as they may reside in a region operation
+ // that is about to be erased.
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
+ for (auto &b : list.blocks) {
+ // blocks that are accessed via multiple codepaths processed once
+ if (b.b->getNumArguments() != b.nonLiveArgs.size())
+ continue;
+ LDBG() << "Erasing " << b.nonLiveArgs.count()
+ << " non-live arguments from block: " << b.b;
+ // Iterate backwards because erasing an argument invalidates subsequent indices.
+ for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
+ if (!b.nonLiveArgs[i])
+ continue;
+ LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
+ b.b->getArgument(i).dropAllUses();
+ b.b->eraseArgument(i);
+ }
+ }
+
+ // 2. Successor Operands
+ LDBG() << "Cleaning up " << list.successorOperands.size()
+ << " successor operand lists";
+ for (auto &op : list.successorOperands) {
+ SuccessorOperands successorOperands =
+ op.branch.getSuccessorOperands(op.successorIndex);
+ // Successor operand lists reached via multiple code paths are processed once.
+ if (successorOperands.size() != op.nonLiveOperands.size())
+ continue;
+ LDBG() << "Erasing " << op.nonLiveOperands.count()
+ << " non-live successor operands from successor "
+ << op.successorIndex << " of branch: "
+ << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
+ // Iterate backwards because erasing invalidates subsequent operand indices.
+ for (int i = successorOperands.size() - 1; i >= 0; --i) {
+ if (!op.nonLiveOperands[i])
+ continue;
+ LDBG() << " Erasing successor operand " << i << ": "
+ << successorOperands[i];
+ successorOperands.erase(i);
+ }
+ }
+
+ // 3. Operations
LDBG() << "Cleaning up " << list.operations.size() << " operations";
- for (auto &op : list.operations) {
+ for (Operation *op : list.operations) {
LDBG() << "Erasing operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
+ if (op->hasTrait<OpTrait::IsTerminator>()) {
+ // When erasing a terminator, insert an unreachable op in its place.
+ OpBuilder b(op);
+ ub::UnreachableOp::create(b, op->getLoc());
+ }
op->dropAllUses();
op->erase();
}
- // 2. Values
+ // 4. Values
LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}
- // 3. Functions
+ // 5. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
// Record which function arguments were erased so we can shrink call-site
// argument segments for CallOpInterface operations (e.g. ops using
@@ -780,7 +865,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
- // 4. Operands
+ // 6. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
// Handle call-specific cleanup only when we have a cached callee reference.
@@ -822,7 +907,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
}
- // 5. Results
+ // 7. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG() << "Erasing " << r.nonLive.count()
@@ -830,48 +915,6 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
<< OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}
-
- // 6. Blocks
- LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
- for (auto &b : list.blocks) {
- // blocks that are accessed via multiple codepaths processed once
- if (b.b->getNumArguments() != b.nonLiveArgs.size())
- continue;
- LDBG() << "Erasing " << b.nonLiveArgs.count()
- << " non-live arguments from block: " << b.b;
- // it iterates backwards because erase invalidates all successor indexes
- for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
- if (!b.nonLiveArgs[i])
- continue;
- LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
- b.b->getArgument(i).dropAllUses();
- b.b->eraseArgument(i);
- }
- }
-
- // 7. Successor Operands
- LDBG() << "Cleaning up " << list.successorOperands.size()
- << " successor operand lists";
- for (auto &op : list.successorOperands) {
- SuccessorOperands successorOperands =
- op.branch.getSuccessorOperands(op.successorIndex);
- // blocks that are accessed via multiple codepaths processed once
- if (successorOperands.size() != op.nonLiveOperands.size())
- continue;
- LDBG() << "Erasing " << op.nonLiveOperands.count()
- << " non-live successor operands from successor "
- << op.successorIndex << " of branch: "
- << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
- // it iterates backwards because erase invalidates all successor indexes
- for (int i = successorOperands.size() - 1; i >= 0; --i) {
- if (!op.nonLiveOperands[i])
- continue;
- LDBG() << " Erasing successor operand " << i << ": "
- << successorOperands[i];
- successorOperands.erase(i);
- }
- }
-
LDBG() << "Finished cleanup of dead values";
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f8c38fa..09ad423 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -25,6 +25,7 @@
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <optional>
+#include <utility>
using namespace mlir;
using namespace mlir::detail;
@@ -975,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
/// Replace the uses of the given value with the given values. The specified
- /// converter is used to build materializations (if necessary).
- void replaceAllUsesWith(Value from, ValueRange to,
- const TypeConverter *converter);
+ /// converter is used to build materializations (if necessary). If `functor`
+ /// is specified, only the uses for which the functor returns "true" are
+ /// replaced.
+ void replaceValueUses(Value from, ValueRange to,
+ const TypeConverter *converter,
+ function_ref<bool(OpOperand &)> functor = nullptr);
/// Erase the given block and its contents.
void eraseBlock(Block *block);
@@ -1051,7 +1055,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
MLIRContext *context,
std::function<void(Operation *)> opErasedCallback = nullptr)
: RewriterBase(context, /*listener=*/this),
- opErasedCallback(opErasedCallback) {}
+ opErasedCallback(std::move(opErasedCallback)) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
@@ -1202,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() {
}
/// Replace all uses of `from` with `repl`.
-static void performReplaceValue(RewriterBase &rewriter, Value from,
- Value repl) {
+static void
+performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
+ function_ref<bool(OpOperand &)> functor = nullptr) {
if (isa<BlockArgument>(repl)) {
// `repl` is a block argument. Directly replace all uses.
- rewriter.replaceAllUsesWith(from, repl);
+ if (functor) {
+ rewriter.replaceUsesWithIf(from, repl, functor);
+ } else {
+ rewriter.replaceAllUsesWith(from, repl);
+ }
return;
}
@@ -1237,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from,
Block *replBlock = replOp->getBlock();
rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
- return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ bool result =
+ user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ if (result && functor)
+ result &= functor(operand);
+ return result;
});
}
@@ -1645,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*isPureTypeConversion=*/false)
.front();
- replaceAllUsesWith(origArg, mat, converter);
+ replaceValueUses(origArg, mat, converter);
continue;
}
@@ -1654,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
+ replaceValueUses(origArg, inputMap->replacementValues, converter);
continue;
}
// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- replaceAllUsesWith(origArg, replArgs, converter);
+ replaceValueUses(origArg, replArgs, converter);
}
if (config.allowPatternRollback)
@@ -1961,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
-void ConversionPatternRewriterImpl::replaceAllUsesWith(
- Value from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceValueUses(
+ Value from, ValueRange to, const TypeConverter *converter,
+ function_ref<bool(OpOperand &)> functor) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Replace Value : '" << from << "'";
+ if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+ if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+ logger.getOStream() << " (in region of '" << parentOp->getName()
+ << "' (" << parentOp << ")";
+ } else {
+ logger.getOStream() << " (unlinked block)";
+ }
+ }
+ if (functor) {
+ logger.getOStream() << ", conditional replacement";
+ }
+ });
+
if (!config.allowPatternRollback) {
SmallVector<Value> toConv = llvm::to_vector(to);
SmallVector<Value> repls =
@@ -1972,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
if (!repl)
return;
- performReplaceValue(r, from, repl);
+ performReplaceValue(r, from, repl, functor);
return;
}
@@ -1991,6 +2020,9 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
replacedValues.insert(from);
#endif // NDEBUG
+ if (functor)
+ llvm::report_fatal_error(
+ "conditional value replacement is not supported in rollback mode");
mapping.map(from, to);
appendRewrite<ReplaceValueRewrite>(from, converter);
}
@@ -2189,18 +2221,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}
void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
- LLVM_DEBUG({
- impl->logger.startLine() << "** Replace Value : '" << from << "'";
- if (auto blockArg = dyn_cast<BlockArgument>(from)) {
- if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
- impl->logger.getOStream() << " (in region of '" << parentOp->getName()
- << "' (" << parentOp << ")\n";
- } else {
- impl->logger.getOStream() << " (unlinked block)\n";
- }
- }
- });
- impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
+ impl->replaceValueUses(from, to, impl->currentTypeConverter);
+}
+
+void ConversionPatternRewriter::replaceUsesWithIf(
+ Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ assert(!allUsesReplaced &&
+ "allUsesReplaced is not supported in a dialect conversion");
+ impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
}
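
A minimal sketch of using the new conditional-replacement hook from a conversion pattern; the values and the predicate are illustrative, and, per the check added above, this path requires running without pattern rollback.

    // Replace only the uses of `oldValue` that are not nested under `regionOp`,
    // leaving the in-region uses to a later pattern (hypothetical scenario).
    rewriter.replaceUsesWithIf(oldValue, newValues, [&](OpOperand &use) {
      return !regionOp->isAncestor(use.getOwner());
    });
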
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2765,7 +2794,7 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
rewriterImpl.patternMaterializations.clear();
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Expensive pattern check that can detect API violations.
- if (checkOp) {
+ if (checkOp && topLevelFingerPrint) {
OperationFingerPrint fingerPrintAfterPattern(checkOp);
if (fingerPrintAfterPattern != *topLevelFingerPrint)
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
@@ -2856,17 +2885,19 @@ LogicalResult OperationLegalizer::legalizePatternResult(
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- // Check that the root was either replaced or updated in place.
- auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
- auto replacedRoot = [&] {
- return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
- };
- auto updatedRootInPlace = [&] {
- return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
- };
- if (!replacedRoot() && !updatedRootInPlace())
- llvm::report_fatal_error(
- "expected pattern to replace the root operation or modify it in place");
+ if (impl.config.allowPatternRollback) {
+ // Check that the root was either replaced or updated in place.
+ auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
+ auto replacedRoot = [&] {
+ return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
+ };
+ auto updatedRootInPlace = [&] {
+ return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
+ };
+ if (!replacedRoot() && !updatedRootInPlace())
+ llvm::report_fatal_error("expected pattern to replace the root operation "
+ "or modify it in place");
+ }
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index 26c965c..4095031 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -613,8 +613,8 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
LLVM_DEBUG({
LDBG() << "* Inliner: Initial calls in SCC are: {";
- for (unsigned i = 0, e = calls.size(); i < e; ++i)
- LDBG() << " " << i << ". " << calls[i].call << ",";
+ for (unsigned I = 0, E = calls.size(); I < E; ++I)
+ LDBG() << " " << I << ". " << calls[I].call << ",";
LDBG() << "}";
});
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31ae1d1..330a2d3 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1149,9 +1149,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Remove the values that already dominate the insertion point.
SmallVector<Value> prunedValues;
for (auto value : values) {
- if (dominance.properlyDominates(value, insertionPoint)) {
+ if (dominance.properlyDominates(value, insertionPoint))
continue;
- }
// Block arguments are not supported.
if (isa<BlockArgument>(value)) {
return rewriter.notifyMatchFailure(
@@ -1178,8 +1177,13 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
options.omitBlockArguments = true;
+ bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
- return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+ bool mustMove =
+ !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+ if (mustMove && !isPure(sliceBoundaryOp))
+ dependsOnSideEffectingOp = true;
+ return mustMove;
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
@@ -1188,6 +1192,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
(void)result;
}
+ // Check if any operation in the slice is side-effecting.
+ if (dependsOnSideEffectingOp)
+ return failure();
+
// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
@@ -1198,9 +1206,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Sort operations topologically before moving.
mlir::topologicalSort(slice);
- for (Operation *op : slice) {
+ for (Operation *op : slice)
rewriter.moveOpBefore(op, insertionPoint);
- }
return success();
}