Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp | 9
-rw-r--r--  mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp | 111
-rw-r--r--  mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp | 17
-rw-r--r--  mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp | 29
-rw-r--r--  mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 46
-rw-r--r--  mlir/lib/Analysis/DataFlowFramework.cpp | 21
-rw-r--r--  mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 2
-rw-r--r--  mlir/lib/Analysis/Presburger/Barvinok.cpp | 14
-rw-r--r--  mlir/lib/Analysis/Presburger/IntegerRelation.cpp | 29
-rw-r--r--  mlir/lib/Analysis/Presburger/Matrix.cpp | 6
-rw-r--r--  mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp | 8
-rw-r--r--  mlir/lib/Analysis/Presburger/Simplex.cpp | 8
-rw-r--r--  mlir/lib/Analysis/TopologicalSortUtils.cpp | 7
-rw-r--r--  mlir/lib/Bindings/Python/DialectGPU.cpp | 9
-rw-r--r--  mlir/lib/Bindings/Python/DialectLLVM.cpp | 16
-rw-r--r--  mlir/lib/Bindings/Python/DialectNVGPU.cpp | 6
-rw-r--r--  mlir/lib/Bindings/Python/DialectPDL.cpp | 14
-rw-r--r--  mlir/lib/Bindings/Python/DialectQuant.cpp | 11
-rw-r--r--  mlir/lib/Bindings/Python/DialectSMT.cpp | 2
-rw-r--r--  mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 7
-rw-r--r--  mlir/lib/Bindings/Python/DialectTransform.cpp | 15
-rw-r--r--  mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 17
-rw-r--r--  mlir/lib/Bindings/Python/Globals.h | 39
-rw-r--r--  mlir/lib/Bindings/Python/IRAffine.cpp | 25
-rw-r--r--  mlir/lib/Bindings/Python/IRAttributes.cpp | 54
-rw-r--r--  mlir/lib/Bindings/Python/IRCore.cpp | 396
-rw-r--r--  mlir/lib/Bindings/Python/IRModule.cpp | 70
-rw-r--r--  mlir/lib/Bindings/Python/IRModule.h | 85
-rw-r--r--  mlir/lib/Bindings/Python/IRTypes.cpp | 2
-rw-r--r--  mlir/lib/Bindings/Python/MainModule.cpp | 23
-rw-r--r--  mlir/lib/Bindings/Python/Pass.cpp | 10
-rw-r--r--  mlir/lib/Bindings/Python/RegisterEverything.cpp | 2
-rw-r--r--  mlir/lib/Bindings/Python/TransformInterpreter.cpp | 1
-rw-r--r--  mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 15
-rw-r--r--  mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 9
-rw-r--r--  mlir/lib/CAPI/IR/BuiltinTypes.cpp | 4
-rw-r--r--  mlir/lib/CAPI/IR/IR.cpp | 4
-rw-r--r--  mlir/lib/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 51
-rw-r--r--  mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 17
-rw-r--r--  mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 37
-rw-r--r--  mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 2
-rw-r--r--  mlir/lib/Conversion/CMakeLists.txt | 3
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 36
-rw-r--r--  mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp | 41
-rw-r--r--  mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp | 33
-rw-r--r--  mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 139
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 3
-rw-r--r--  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 6
-rw-r--r--  mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 85
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp | 62
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 121
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 115
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 116
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp | 56
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 15
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 43
-rw-r--r--  mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 10
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp | 54
-rw-r--r--  mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt | 17
-rw-r--r--  mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp | 440
-rw-r--r--  mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 7
-rw-r--r--  mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 4
-rw-r--r--  mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 4
-rw-r--r--  mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 17
-rw-r--r--  mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 10
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 4
-rw-r--r--  mlir/lib/Conversion/TosaToArith/TosaToArith.cpp | 14
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 90
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 16
-rw-r--r--  mlir/lib/Conversion/VectorToAMX/CMakeLists.txt | 19
-rw-r--r--  mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 429
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 22
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 40
-rw-r--r--  mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 313
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 27
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 1026
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 12
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 78
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 199
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 93
-rw-r--r--  mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp | 12
-rw-r--r--  mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp | 60
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp | 82
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp | 14
-rw-r--r--  mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/DLTI/Traits.cpp | 2
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 62
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 150
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 47
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp | 13
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp | 57
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 3
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp | 48
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp | 1
-rw-r--r--  mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp | 38
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp | 487
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 120
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 71
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 31
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 420
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp | 7
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 508
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp | 256
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 275
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 22
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp | 65
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp (renamed from mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp) | 15
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 132
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (renamed from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp) | 139
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp | 6
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp | 70
-rw-r--r--  mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 6
-rw-r--r--  mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 6
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 78
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 153
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 20
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 122
-rw-r--r--  mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 52
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 3
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp | 9
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 15
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 63
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp | 46
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Partition.cpp | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp | 7
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 568
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 23
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 238
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 40
-rw-r--r--  mlir/lib/Dialect/Transform/IR/Utils.cpp | 33
-rw-r--r--  mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 179
-rw-r--r--  mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp | 65
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 45
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 63
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 39
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 26
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 214
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 385
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 39
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 20
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 585
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 98
-rw-r--r--  mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 20
-rw-r--r--  mlir/lib/ExecutionEngine/JitRunner.cpp | 2
-rw-r--r--  mlir/lib/ExecutionEngine/VulkanRuntime.cpp | 2
-rw-r--r--  mlir/lib/IR/Block.cpp | 10
-rw-r--r--  mlir/lib/IR/BuiltinAttributes.cpp | 17
-rw-r--r--  mlir/lib/IR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/IR/Dialect.cpp | 12
-rw-r--r--  mlir/lib/IR/Location.cpp | 4
-rw-r--r--  mlir/lib/IR/MLIRContext.cpp | 32
-rw-r--r--  mlir/lib/IR/ODSSupport.cpp | 4
-rw-r--r--  mlir/lib/IR/Remarks.cpp | 279
-rw-r--r--  mlir/lib/Interfaces/SideEffectInterfaces.cpp | 5
-rw-r--r--  mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp | 3
-rw-r--r--  mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 2
-rw-r--r--  mlir/lib/Query/Matcher/VariantValue.cpp | 5
-rw-r--r--  mlir/lib/Query/Query.cpp | 2
-rw-r--r--  mlir/lib/RegisterAllDialects.cpp | 2
-rw-r--r--  mlir/lib/RegisterAllExtensions.cpp | 3
-rw-r--r--  mlir/lib/RegisterAllPasses.cpp | 2
-rw-r--r--  mlir/lib/Remark/CMakeLists.txt | 14
-rw-r--r--  mlir/lib/Remark/RemarkStreamer.cpp | 69
-rw-r--r--  mlir/lib/Rewrite/ByteCode.cpp | 300
-rw-r--r--  mlir/lib/Rewrite/PatternApplicator.cpp | 6
-rw-r--r--  mlir/lib/Target/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp | 52
-rw-r--r--  mlir/lib/Target/LLVM/CMakeLists.txt | 24
-rw-r--r--  mlir/lib/Target/LLVM/NVVM/Target.cpp | 48
-rw-r--r--  mlir/lib/Target/LLVM/XeVM/Target.cpp | 418
-rw-r--r--  mlir/lib/Target/LLVMIR/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp | 63
-rw-r--r--  mlir/lib/Target/LLVMIR/DataLayoutImporter.h | 132
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 12
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 139
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 100
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt | 12
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 66
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleImport.cpp | 140
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 180
-rw-r--r--  mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt | 23
-rw-r--r--  mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp | 62
-rw-r--r--  mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp | 78
-rw-r--r--  mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp | 71
-rw-r--r--  mlir/lib/Target/LLVMIR/TypeToLLVM.cpp | 11
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 2
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 20
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.h | 2
-rw-r--r--  mlir/lib/Target/SPIRV/TranslateRegistration.cpp | 59
-rw-r--r--  mlir/lib/Target/Wasm/CMakeLists.txt | 13
-rw-r--r--  mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 1522
-rw-r--r--  mlir/lib/Target/Wasm/TranslateRegistration.cpp | 28
-rw-r--r--  mlir/lib/Tools/PDLL/Parser/Parser.cpp | 3
-rw-r--r--  mlir/lib/Tools/mlir-query/MlirQueryMain.cpp | 11
-rw-r--r--  mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp | 13
-rw-r--r--  mlir/lib/Transforms/CSE.cpp | 12
-rw-r--r--  mlir/lib/Transforms/InlinerPass.cpp | 7
-rw-r--r--  mlir/lib/Transforms/Mem2Reg.cpp | 4
-rw-r--r--  mlir/lib/Transforms/RemoveDeadValues.cpp | 49
-rw-r--r--  mlir/lib/Transforms/SROA.cpp | 4
-rw-r--r--  mlir/lib/Transforms/SymbolDCE.cpp | 30
-rw-r--r--  mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp | 9
-rw-r--r--  mlir/lib/Transforms/Utils/DialectConversion.cpp | 703
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 42
-rw-r--r--  mlir/lib/Transforms/Utils/InliningUtils.cpp | 13
-rw-r--r--  mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp | 12
-rw-r--r--  mlir/lib/Transforms/Utils/RegionUtils.cpp | 33
-rw-r--r--  mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 126
271 files changed, 13316 insertions, 3761 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 6cece46..8062b474 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -127,9 +127,12 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
Operation *op = result.getOwner();
// If this is a view, unwrap to the source.
- if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op))
- return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
- visited, output);
+ if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) {
+ if (result == view.getViewDest()) {
+ return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
+ visited, output);
+ }
+ }
// Check to see if we can reason about the control flow of this op.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result,
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 10874fd..9424eff 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
@@ -78,9 +79,17 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
void PredecessorState::print(raw_ostream &os) const {
if (allPredecessorsKnown())
os << "(all) ";
- os << "predecessors:\n";
- for (Operation *op : getKnownPredecessors())
- os << " " << *op << "\n";
+ os << "predecessors:";
+ if (getKnownPredecessors().empty())
+ os << " (none)";
+ else
+ os << "\n";
+ llvm::interleave(
+ getKnownPredecessors(), os,
+ [&](Operation *op) {
+ os << " " << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ },
+ "\n");
}
ChangeResult PredecessorState::join(Operation *predecessor) {
@@ -127,7 +136,7 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
LDBG() << "Initializing DeadCodeAnalysis for top-level op: "
- << top->getName();
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
// Mark the top-level blocks as executable.
for (Region &region : top->getRegions()) {
if (region.empty())
@@ -135,7 +144,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
- LDBG() << "Marked entry block live for region in op: " << top->getName();
+ LDBG() << "Marked entry block live for region in op: "
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
}
// Mark as overdefined the predecessors of symbol callables with potentially
@@ -147,17 +157,19 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
- << top->getName();
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
- LDBG() << "[init] Processing symbol table op: " << symTable->getName();
+ LDBG() << "[init] Processing symbol table op: "
+ << OpWithFlags(symTable, OpPrintingFlags().skipRegions());
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();
bool foundSymbolCallable = false;
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
LDBG() << "[init] Found CallableOpInterface: "
- << callable.getOperation()->getName();
+ << OpWithFlags(callable.getOperation(),
+ OpPrintingFlags().skipRegions());
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
@@ -172,7 +184,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Marked callable as having unknown predecessors: "
- << callable.getOperation()->getName();
+ << OpWithFlags(callable.getOperation(),
+ OpPrintingFlags().skipRegions());
}
foundSymbolCallable = true;
}
@@ -195,7 +208,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Marked nested callable as "
"having unknown predecessors: "
- << callable.getOperation()->getName();
+ << OpWithFlags(callable.getOperation(),
+ OpPrintingFlags().skipRegions());
});
}
@@ -211,13 +225,13 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Found non-call use for symbol, "
"marked as having unknown predecessors: "
- << symbol->getName();
+ << OpWithFlags(symbol, OpPrintingFlags().skipRegions());
}
};
SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
walkFn);
LDBG() << "[init] Finished initializeSymbolCallables for top-level op: "
- << top->getName();
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
}
/// Returns true if the operation is a returning terminator in region
@@ -229,12 +243,13 @@ static bool isRegionOrCallableReturn(Operation *op) {
}
LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
- LDBG() << "[init] Entering initializeRecursively for op: " << op->getName()
- << " at " << op;
+ LDBG() << "[init] Entering initializeRecursively for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
- LDBG() << "[init] Visiting op with control-flow semantics: " << *op;
+ LDBG() << "[init] Visiting op with control-flow semantics: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// When the liveness of the parent block changes, make sure to
// re-invoke the analysis on the op.
if (op->getBlock())
@@ -246,16 +261,17 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
// Recurse on nested operations.
for (Region &region : op->getRegions()) {
- LDBG() << "[init] Recursing into region of op: " << op->getName();
+ LDBG() << "[init] Recursing into region of op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Operation &nestedOp : region.getOps()) {
- LDBG() << "[init] Recursing into nested op: " << nestedOp.getName()
- << " at " << &nestedOp;
+ LDBG() << "[init] Recursing into nested op: "
+ << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
if (failed(initializeRecursively(&nestedOp)))
return failure();
}
}
- LDBG() << "[init] Finished initializeRecursively for op: " << op->getName()
- << " at " << op;
+ LDBG() << "[init] Finished initializeRecursively for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return success();
}
@@ -269,35 +285,40 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
}
void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
- LDBG() << "Marking entry blocks live for op: " << op->getName();
+ LDBG() << "Marking entry blocks live for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
- LDBG() << "Marked entry block live for region in op: " << op->getName();
+ LDBG() << "Marked entry block live for region in op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
}
}
LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
- LDBG() << "Visiting program point: " << point << " " << *point;
+ LDBG() << "Visiting program point: " << *point;
if (point->isBlockStart())
return success();
Operation *op = point->getPrevOp();
- LDBG() << "Visiting operation: " << *op;
+ LDBG() << "Visiting operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the parent block is not executable, there is nothing to do.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->isLive()) {
- LDBG() << "Parent block not live, skipping op: " << *op;
+ LDBG() << "Parent block not live, skipping op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return success();
}
// We have a live call op. Add this as a live predecessor of the callee.
if (auto call = dyn_cast<CallOpInterface>(op)) {
- LDBG() << "Visiting call operation: " << *op;
+ LDBG() << "Visiting call operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
visitCallOperation(call);
}
@@ -305,12 +326,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumRegions()) {
// Check if we can reason about the region control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG() << "Visiting region branch operation: " << *op;
+ LDBG() << "Visiting region branch operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
visitRegionBranchOperation(branch);
// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
- LDBG() << "Visiting callable operation: " << *op;
+ LDBG() << "Visiting callable operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
@@ -322,19 +345,22 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
// Otherwise, conservatively mark all entry blocks as executable.
} else {
- LDBG() << "Marking all entry blocks live for op: " << *op;
+ LDBG() << "Marking all entry blocks live for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
markEntryBlocksLive(op);
}
}
if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
- LDBG() << "Visiting region terminator: " << *op;
+ LDBG() << "Visiting region terminator: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
- LDBG() << "Visiting callable terminator: " << *op;
+ LDBG() << "Visiting callable terminator: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Visit the exiting terminator of a callable.
visitCallableTerminator(op, callable);
}
@@ -343,12 +369,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumSuccessors()) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
- LDBG() << "Visiting branch operation: " << *op;
+ LDBG() << "Visiting branch operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
visitBranchOperation(branch);
// Otherwise, conservatively mark all successors as exectuable.
} else {
- LDBG() << "Marking all successors live for op: " << *op;
+ LDBG() << "Marking all successors live for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Block *successor : op->getSuccessors())
markEdgeLive(op->getBlock(), successor);
}
@@ -358,7 +386,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
}
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
- LDBG() << "visitCallOperation: " << call.getOperation()->getName();
+ LDBG() << "visitCallOperation: "
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
// A call to a externally-defined callable has unknown predecessors.
@@ -382,14 +411,14 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
propagateIfChanged(callsites, callsites->join(call));
LDBG() << "Added callsite as predecessor for callable: "
- << callableOp->getName();
+ << OpWithFlags(callableOp, OpPrintingFlags().skipRegions());
} else {
// Mark this call op's predecessors as overdefined.
auto *predecessors =
getOrCreate<PredecessorState>(getProgramPointAfter(call));
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
LDBG() << "Marked call op's predecessors as unknown for: "
- << call.getOperation()->getName();
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
}
}
@@ -421,7 +450,8 @@ DeadCodeAnalysis::getOperandValues(Operation *op) {
}
void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
- LDBG() << "visitBranchOperation: " << branch.getOperation()->getName();
+ LDBG() << "visitBranchOperation: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
// Try to deduce a single successor for the branch.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
@@ -440,7 +470,8 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
void DeadCodeAnalysis::visitRegionBranchOperation(
RegionBranchOpInterface branch) {
- LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName();
+ LDBG() << "visitRegionBranchOperation: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
// Try to deduce which regions are executable.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
@@ -517,14 +548,14 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
LDBG() << "Added callable terminator as predecessor for callsite: "
- << predecessor->getName();
+ << OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
propagateIfChanged(predecessors,
predecessors->setHasUnknownPredecessors());
LDBG() << "Could not resolve callable terminator for callsite: "
- << predecessor->getName();
+ << OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
}
}
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index c7a950d..e79f6a8 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -28,6 +30,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
#include <utility>
@@ -87,7 +90,8 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
return success();
}
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ LDBG() << "Inferring ranges for "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
auto argRanges = llvm::map_to_vector(
operands, [](const IntegerValueRangeLattice *lattice) {
return lattice->getValue();
@@ -99,7 +103,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
return;
assert(llvm::is_contained(op->getResults(), result));
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ LDBG() << "Inferred range " << attrs;
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
IntegerValueRange oldRange = lattice->getValue();
@@ -114,7 +118,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
});
if (isYieldedResult && !oldRange.isUninitialized() &&
!(lattice->getValue() == oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ LDBG() << "Loop variant loop result detected";
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
propagateIfChanged(lattice, changed);
@@ -128,7 +132,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &successor,
ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ LDBG() << "Inferring ranges for "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
@@ -141,7 +146,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
return;
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ LDBG() << "Inferred range " << attrs;
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
IntegerValueRange oldRange = lattice->getValue();
@@ -156,7 +161,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
});
if (isYieldedValue && !oldRange.isUninitialized() &&
!(lattice->getValue() == oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ LDBG() << "Loop variant loop result detected";
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
propagateIfChanged(lattice, changed);
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 509f520..65df355 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
+ << " check on unreachable code now:";
+ // The framework doesn't visit operations in dead blocks, so we need to
+ // explicitly mark them as dead.
+ op->walk([&](Operation *op) {
+ if (op->getNumResults() == 0)
+ return;
+ for (auto result : llvm::enumerate(op->getResults())) {
+ if (getLiveness(result.value()))
+ continue;
+ LDBG() << "Result: " << result.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info (unreachable), mark dead";
+ solver.getOrCreateState<Liveness>(result.value());
+ }
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
+ for (auto blockArg : llvm::enumerate(block.getArguments())) {
+ if (getLiveness(blockArg.value()))
+ continue;
+ LDBG() << "Block argument: " << blockArg.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info, mark dead";
+ solver.getOrCreateState<Liveness>(blockArg.value());
+ }
+ }
+ }
+ });
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e625f62..13a3e14 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -19,12 +19,15 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
+#define DEBUG_TYPE "dataflow"
+
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
@@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+ LDBG() << "Initializing recursively for operation: " << op->getName();
+
// Initialize the analysis by visiting every owner of an SSA value (all
// operations and blocks).
- if (failed(visitOperation(op)))
+ if (failed(visitOperation(op))) {
+ LDBG() << "Failed to visit operation: " << op->getName();
return failure();
+ }
for (Region &region : op->getRegions()) {
+ LDBG() << "Processing region with " << region.getBlocks().size()
+ << " blocks";
for (Block &block : region) {
+ LDBG() << "Processing block with " << block.getNumArguments()
+ << " arguments";
getOrCreate<Executable>(getProgramPointBefore(&block))
->blockContentSubscribe(this);
visitBlock(&block);
- for (Operation &op : block)
- if (failed(initializeRecursively(&op)))
+ for (Operation &op : block) {
+ LDBG() << "Recursively initializing nested operation: " << op.getName();
+ if (failed(initializeRecursively(&op))) {
+ LDBG() << "Failed to initialize nested operation: " << op.getName();
return failure();
+ }
+ }
}
}
+ LDBG() << "Successfully completed recursive initialization for operation: "
+ << op->getName();
return success();
}
@@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+ LDBG() << "Visiting operation: " << op->getName() << " with "
+ << op->getNumOperands() << " operands and " << op->getNumResults()
+ << " results";
+
// If we're in a dead block, bail out.
if (op->getBlock() != nullptr &&
- !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+ ->isLive()) {
+ LDBG() << "Operation is in dead block, bailing out";
return success();
+ }
+ LDBG() << "Creating lattice elements for " << op->getNumOperands()
+ << " operands and " << op->getNumResults() << " results";
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Block arguments of region branch operations flow back into the operands
// of the parent op
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchOpInterface operation";
visitRegionSuccessors(branch, operandLattices);
return success();
}
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+ LDBG() << "Processing BranchOpInterface operation with "
+ << op->getNumSuccessors() << " successors";
+
// Block arguments of successor blocks flow back into our operands.
// We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ LDBG() << "Processing CallOpInterface operation";
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
@@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// of this op itself and the operands of the terminators of the regions of
// this op.
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
visitRegionSuccessorsFromTerminator(terminator, branch);
return success();
@@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
}
if (op->hasTrait<OpTrait::ReturnLike>()) {
+ LDBG() << "Processing ReturnLike operation";
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+ LDBG() << "Callable parent found, visiting callable operation";
return visitCallableOperation(op, callable, operandLattices);
+ }
}
+ LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
return visitOperationImpl(op, operandLattices, resultLattices);
}
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 16f7033..7e1b405 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -45,7 +45,7 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
DATAFLOW_DEBUG({
if (inserted) {
LDBG() << "Creating dependency between " << debugName << " of " << anchor
- << "\nand " << debugName << " on " << dependent;
+ << "\nand " << debugName << " on " << *dependent;
}
});
}
@@ -62,11 +62,12 @@ void ProgramPoint::print(raw_ostream &os) const {
return;
}
if (!isBlockStart()) {
- os << "<after operation>:";
- return getPrevOp()->print(os, OpPrintingFlags().skipRegions());
+ os << "<after operation>:"
+ << OpWithFlags(getPrevOp(), OpPrintingFlags().skipRegions());
+ return;
}
- os << "<before operation>:";
- return getNextOp()->print(os, OpPrintingFlags().skipRegions());
+ os << "<before operation>:"
+ << OpWithFlags(getNextOp(), OpPrintingFlags().skipRegions());
}
//===----------------------------------------------------------------------===//
@@ -78,8 +79,8 @@ void LatticeAnchor::print(raw_ostream &os) const {
os << "<NULL POINT>";
return;
}
- if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
- return LatticeAnchor->print(os);
+ if (auto *latticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
+ return latticeAnchor->print(os);
if (auto value = llvm::dyn_cast<Value>(*this)) {
return value.print(os, OpPrintingFlags().skipRegions());
}
@@ -88,8 +89,8 @@ void LatticeAnchor::print(raw_ostream &os) const {
}
Location LatticeAnchor::getLoc() const {
- if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
- return LatticeAnchor->getLoc();
+ if (auto *latticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
+ return latticeAnchor->getLoc();
if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
@@ -128,7 +129,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
worklist.pop();
DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName
- << "' on: " << point);
+ << "' on: " << *point);
if (failed(analysis->visit(point)))
return failure();
}
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index f4b02b4..30ce1fb 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -60,7 +60,7 @@ private:
AffineExpr localExpr) override {
SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
// Update localVarCst.
- localVarCst.addLocalFloorDiv(dividend, divisor);
+ (void)localVarCst.addLocalFloorDiv(dividend, divisor);
}
LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp
index 4546e49..75d592e 100644
--- a/mlir/lib/Analysis/Presburger/Barvinok.cpp
+++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp
@@ -554,7 +554,7 @@ QuasiPolynomial mlir::presburger::detail::getCoefficientInRationalFunction(
/// t^num / \prod_j (1 - t^dens[j]).
/// v represents the affine functions whose floors are multiplied by the
/// generators, and ds represents the list of generators.
-std::pair<QuasiPolynomial, std::vector<Fraction>>
+static std::pair<QuasiPolynomial, std::vector<Fraction>>
substituteMuInTerm(unsigned numParams, const ParamPoint &v,
const std::vector<Point> &ds, const Point &mu) {
unsigned numDims = mu.size();
@@ -606,8 +606,8 @@ substituteMuInTerm(unsigned numParams, const ParamPoint &v,
/// Here, sign = ± 1,
/// num is a QuasiPolynomial, and
/// each dens[j] is a Fraction.
-void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num,
- std::vector<Fraction> &dens) {
+static void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num,
+ std::vector<Fraction> &dens) {
// We track the number of exponents that are negative in the
// denominator, and convert them to their absolute values.
unsigned numNegExps = 0;
@@ -634,8 +634,8 @@ void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num,
/// Compute the binomial coefficients nCi for 0 ≤ i ≤ r,
/// where n is a QuasiPolynomial.
-std::vector<QuasiPolynomial> getBinomialCoefficients(const QuasiPolynomial &n,
- unsigned r) {
+static std::vector<QuasiPolynomial>
+getBinomialCoefficients(const QuasiPolynomial &n, unsigned r) {
unsigned numParams = n.getNumInputs();
std::vector<QuasiPolynomial> coefficients;
coefficients.reserve(r + 1);
@@ -651,8 +651,8 @@ std::vector<QuasiPolynomial> getBinomialCoefficients(const QuasiPolynomial &n,
/// Compute the binomial coefficients nCi for 0 ≤ i ≤ r,
/// where n is a QuasiPolynomial.
-std::vector<Fraction> getBinomialCoefficients(const Fraction &n,
- const Fraction &r) {
+static std::vector<Fraction> getBinomialCoefficients(const Fraction &n,
+ const Fraction &r) {
std::vector<Fraction> coefficients;
coefficients.reserve((int64_t)floor(r));
coefficients.emplace_back(1);
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 5c4d4d1..0dcdd5b 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1500,12 +1500,13 @@ void IntegerRelation::addBound(BoundType type, ArrayRef<DynamicAPInt> expr,
/// respect to a positive constant 'divisor'. Two constraints are added to the
/// system to capture equivalence with the floordiv.
/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
-void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
- const DynamicAPInt &divisor) {
+/// Returns the column position of the new local variable.
+unsigned IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
+ const DynamicAPInt &divisor) {
assert(dividend.size() == getNumCols() && "incorrect dividend size");
assert(divisor > 0 && "positive divisor expected");
- appendVar(VarKind::Local);
+ unsigned newVar = appendVar(VarKind::Local);
SmallVector<DynamicAPInt, 8> dividendCopy(dividend);
dividendCopy.insert(dividendCopy.end() - 1, DynamicAPInt(0));
@@ -1513,6 +1514,28 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
addInequality(
getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
+ return newVar;
+}
+
+unsigned IntegerRelation::addLocalModulo(ArrayRef<DynamicAPInt> exprs,
+ const DynamicAPInt &modulus) {
+ assert(exprs.size() == getNumCols() && "incorrect exprs size");
+ assert(modulus > 0 && "positive modulus expected");
+
+ /// Add a local variable for q = expr floordiv modulus
+ addLocalFloorDiv(exprs, modulus);
+
+ /// Add a local var to represent the result
+ auto resultIndex = appendVar(VarKind::Local);
+
+ SmallVector<DynamicAPInt, 8> exprsCopy(exprs);
+ /// Insert the two new locals before the constant
+ /// Add locals that correspond to `q` and `result` to compute
+ /// 0 = (expr - modulus * q) - result
+ exprsCopy.insert(exprsCopy.end() - 1,
+ {DynamicAPInt(-modulus), DynamicAPInt(-1)});
+ addEquality(exprsCopy);
+ return resultIndex;
}
int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const {
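(Illustrative aside, not from the commit: a worked instance of the encoding the new addLocalModulo produces, using the floordiv equivalence quoted in the comment above; the concrete expression and modulus are made up for the example.)

For expr = x and modulus = 3, the nested addLocalFloorDiv introduces a local q constrained by
$$3q \le x \le 3q + 2 \;\Longleftrightarrow\; q = \lfloor x/3 \rfloor,$$
and the added equality x - 3q - r = 0 then fixes the second local to
$$r = x - 3\lfloor x/3 \rfloor = x \bmod 3, \qquad 0 \le r \le 2,$$
whose column position is what addLocalModulo returns as resultIndex.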
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 9fc6205..bb60564 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -402,10 +402,10 @@ void Matrix<T>::print(raw_ostream &os) const {
for (unsigned row = 0; row < nRows; ++row)
for (unsigned column = 0; column < nColumns; ++column)
updatePrintMetrics<T>(at(row, column), ptm);
- unsigned MIN_SPACING = 1;
+ unsigned minSpacing = 1;
for (unsigned row = 0; row < nRows; ++row) {
for (unsigned column = 0; column < nColumns; ++column) {
- printWithPrintMetrics<T>(os, at(row, column), MIN_SPACING, ptm);
+ printWithPrintMetrics<T>(os, at(row, column), minSpacing, ptm);
}
os << "\n";
}
@@ -721,7 +721,7 @@ FracMatrix FracMatrix::gramSchmidt() const {
// Otherwise, we swap b_k and b_{k-1} and decrement k.
//
// We repeat this until k = n and return.
-void FracMatrix::LLL(Fraction delta) {
+void FracMatrix::LLL(const Fraction &delta) {
DynamicAPInt nearest;
Fraction mu;
diff --git a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
index 84d885f..4e374d0 100644
--- a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
+++ b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
@@ -112,14 +112,14 @@ QuasiPolynomial QuasiPolynomial::simplify() {
// A term is zero if its coefficient is zero, or
if (coefficients[i] == Fraction(0, 1))
continue;
- bool product_is_zero =
+ bool productIsZero =
// if any of the affine functions in the product
- llvm::any_of(affine[i], [](const SmallVector<Fraction> &affine_ij) {
+ llvm::any_of(affine[i], [](const SmallVector<Fraction> &affineIj) {
// has all its coefficients as zero.
- return llvm::all_of(affine_ij,
+ return llvm::all_of(affineIj,
[](const Fraction &f) { return f == 0; });
});
- if (product_is_zero)
+ if (productIsZero)
continue;
// Now, we know the term is nonzero.
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 08290db..a1cbe29 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -433,7 +433,7 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
normalizeDiv(divCoeffs, divDenom);
domainSimplex.addDivisionVariable(divCoeffs, divDenom);
- domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
+ (void)domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
// Update `this` to account for the additional symbol we just added.
appendSymbol();
@@ -1663,7 +1663,7 @@ public:
/// First pushes a snapshot for the current simplex state to the stack so
/// that this can be rolled back later.
void addEqualityForDirection(ArrayRef<DynamicAPInt> dir) {
- assert(llvm::any_of(dir, [](const DynamicAPInt &x) { return x != 0; }) &&
+ assert(llvm::any_of(dir, [](const DynamicAPInt &X) { return X != 0; }) &&
"Direction passed is the zero vector!");
snapshotStack.emplace_back(simplex.getSnapshot());
simplex.addEquality(getCoeffsForDirection(dir));
@@ -2156,10 +2156,10 @@ void SimplexBase::print(raw_ostream &os) const {
for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row)
for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col)
updatePrintMetrics<DynamicAPInt>(tableau(row, col), ptm);
- unsigned MIN_SPACING = 1;
+ unsigned minSpacing = 1;
for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row) {
for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col) {
- printWithPrintMetrics<DynamicAPInt>(os, tableau(row, col), MIN_SPACING,
+ printWithPrintMetrics<DynamicAPInt>(os, tableau(row, col), minSpacing,
ptm);
}
os << '\n';
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index a2fd149..99546e7 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -101,12 +101,7 @@ bool mlir::sortTopologically(
bool mlir::sortTopologically(
Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
- if (block->empty())
- return true;
- if (block->back().hasTrait<OpTrait::IsTerminator>())
- return sortTopologically(block, block->without_terminator(),
- isOperandReady);
- return sortTopologically(block, *block, isOperandReady);
+ return sortTopologically(block, block->without_terminator(), isOperandReady);
}
bool mlir::computeTopologicalSorting(
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index e5045cf..a21176fff 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,8 +9,8 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
@@ -34,7 +34,7 @@ NB_MODULE(_mlirDialectsGPU, m) {
mlirGPUAsyncTokenType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirGPUAsyncTokenTypeGet(ctx));
},
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
@@ -47,8 +47,9 @@ NB_MODULE(_mlirDialectsGPU, m) {
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
.def_classmethod(
"get",
- [](nb::object cls, MlirAttribute target, uint32_t format,
- nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
+ [](const nb::object &cls, MlirAttribute target, uint32_t format,
+ const nb::bytes &object,
+ std::optional<MlirAttribute> mlirObjectProps,
std::optional<MlirAttribute> mlirKernelsAttr) {
MlirStringRef objectStrRef = mlirStringRefCreate(
static_cast<char *>(const_cast<void *>(object.data())),
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index f211e76..ee106c0 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -12,8 +12,8 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
@@ -24,7 +24,7 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
+static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
//===--------------------------------------------------------------------===//
// StructType
@@ -35,8 +35,8 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"get_literal",
- [](nb::object cls, const std::vector<MlirType> &elements, bool packed,
- MlirLocation loc) {
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
@@ -51,7 +51,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"get_identified",
- [](nb::object cls, const std::string &name, MlirContext context) {
+ [](const nb::object &cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
@@ -59,7 +59,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"get_opaque",
- [](nb::object cls, const std::string &name, MlirContext context) {
+ [](const nb::object &cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
@@ -79,7 +79,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"new_identified",
- [](nb::object cls, const std::string &name,
+ [](const nb::object &cls, const std::string &name,
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
return cls(mlirLLVMStructTypeIdentifiedNewGet(
ctx, mlirStringRefCreate(name.data(), name.length()),
@@ -123,7 +123,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
.def_classmethod(
"get",
- [](nb::object cls, std::optional<unsigned> addressSpace,
+ [](const nb::object &cls, std::optional<unsigned> addressSpace,
MlirContext context) {
CollectDiagnosticsToStringScope scope(context);
MlirType type = mlirLLVMPointerTypeGet(
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index a0d6a4b..bb3f519c 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -8,8 +8,8 @@
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
@@ -23,8 +23,8 @@ static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
nvgpuTensorMapDescriptorType.def_classmethod(
"get",
- [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
- int oobFill, int interleave, MlirContext ctx) {
+ [](const nb::object &cls, MlirType tensorMemrefType, int swizzle,
+ int l2promo, int oobFill, int interleave, MlirContext ctx) {
return cls(mlirNVGPUTensorMapDescriptorTypeGet(
ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
},
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index bcc6ff4..2acedbc 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -8,8 +8,8 @@
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
@@ -17,7 +17,7 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectPDLSubmodule(const nanobind::module_ &m) {
+static void populateDialectPDLSubmodule(const nanobind::module_ &m) {
//===-------------------------------------------------------------------===//
// PDLType
//===-------------------------------------------------------------------===//
@@ -32,7 +32,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
attributeType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLAttributeTypeGet(ctx));
},
"Get an instance of AttributeType in given context.", nb::arg("cls"),
@@ -46,7 +46,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
operationType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLOperationTypeGet(ctx));
},
"Get an instance of OperationType in given context.", nb::arg("cls"),
@@ -59,7 +59,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
rangeType.def_classmethod(
"get",
- [](nb::object cls, MlirType elementType) {
+ [](const nb::object &cls, MlirType elementType) {
return cls(mlirPDLRangeTypeGet(elementType));
},
"Gets an instance of RangeType in the same context as the provided "
@@ -77,7 +77,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
typeType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLTypeTypeGet(ctx));
},
"Get an instance of TypeType in given context.", nb::arg("cls"),
@@ -90,7 +90,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
valueType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLValueTypeGet(ctx));
},
"Get an instance of TypeType in given context.", nb::arg("cls"),
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 55571cd..a5220fc 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -165,7 +165,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
anyQuantizedType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
@@ -186,7 +186,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
uniformQuantizedType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
return cls(mlirUniformQuantizedTypeGet(flags, storageType,
@@ -221,7 +221,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
uniformQuantizedPerAxisType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, std::vector<double> scales,
std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
@@ -293,7 +293,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
uniformQuantizedSubChannelType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
std::vector<int32_t> quantizedDimensions,
std::vector<int64_t> blockSizes, int64_t storageTypeMin,
@@ -367,7 +367,8 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
calibratedQuantizedType.def_classmethod(
"get",
- [](nb::object cls, MlirType expressedType, double min, double max) {
+ [](const nb::object &cls, MlirType expressedType, double min,
+ double max) {
return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
},
"Gets an instance of CalibratedQuantizedType in the same context as the "
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 4e76477..cab4219 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -24,7 +24,7 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectSMTSubmodule(nanobind::module_ &m) {
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
.def_classmethod(
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 97cebcc..9d7dc11 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -12,8 +12,8 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
@@ -38,7 +38,8 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
mlirAttributeIsASparseTensorEncodingAttr)
.def_classmethod(
"get",
- [](nb::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+ [](const nb::object &cls,
+ std::vector<MlirSparseTensorLevelType> lvlTypes,
std::optional<MlirAffineMap> dimToLvl,
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
std::optional<MlirAttribute> explicitVal,
@@ -58,7 +59,7 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
"Gets a sparse_tensor.encoding from parameters.")
.def_classmethod(
"build_level_type",
- [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt,
+ [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt,
const std::vector<MlirSparseTensorLevelPropertyNondefault>
&properties,
unsigned n, unsigned m) {
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 59a030a..1a62b06 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,15 +11,15 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectTransformSubmodule(const nb::module_ &m) {
+static void populateDialectTransformSubmodule(const nb::module_ &m) {
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -29,7 +29,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformAnyOpTypeGetTypeID);
anyOpType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirTransformAnyOpTypeGet(ctx));
},
"Get an instance of AnyOpType in the given context.", nb::arg("cls"),
@@ -44,7 +44,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformAnyParamTypeGetTypeID);
anyParamType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirTransformAnyParamTypeGet(ctx));
},
"Get an instance of AnyParamType in the given context.", nb::arg("cls"),
@@ -59,7 +59,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformAnyValueTypeGetTypeID);
anyValueType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirTransformAnyValueTypeGet(ctx));
},
"Get an instance of AnyValueType in the given context.", nb::arg("cls"),
@@ -74,7 +74,8 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformOperationTypeGetTypeID);
operationType.def_classmethod(
"get",
- [](nb::object cls, const std::string &operationName, MlirContext ctx) {
+ [](const nb::object &cls, const std::string &operationName,
+ MlirContext ctx) {
MlirStringRef cOperationName =
mlirStringRefCreate(operationName.data(), operationName.size());
return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
@@ -101,7 +102,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformParamTypeGetTypeID);
paramType.def_classmethod(
"get",
- [](nb::object cls, MlirType type, MlirContext ctx) {
+ [](const nb::object &cls, MlirType type, MlirContext ctx) {
return cls(mlirTransformParamTypeGet(ctx, type));
},
"Get an instance of ParamType for the given type in the given context.",
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index 81dada3..8bb493e 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/ExecutionEngine.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
@@ -45,7 +45,7 @@ public:
referencedObjects.push_back(obj);
}
- static nb::object createFromCapsule(nb::object capsule) {
+ static nb::object createFromCapsule(const nb::object &capsule) {
MlirExecutionEngine rawPm =
mlirPythonCapsuleToExecutionEngine(capsule.ptr());
if (mlirExecutionEngineIsNull(rawPm))
@@ -113,7 +113,7 @@ NB_MODULE(_mlirExecutionEngine, m) {
.def(
"raw_register_runtime",
[](PyExecutionEngine &executionEngine, const std::string &name,
- nb::object callbackObj) {
+ const nb::object &callbackObj) {
executionEngine.addReferencedObject(callbackObj);
uintptr_t rawSym =
nb::cast<uintptr_t>(nb::getattr(callbackObj, "value"));
@@ -125,6 +125,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
nb::arg("name"), nb::arg("callback"),
"Register `callback` as the runtime symbol `name`.")
.def(
+ "initialize",
+ [](PyExecutionEngine &executionEngine) {
+ mlirExecutionEngineInitialize(executionEngine.get());
+ },
+ "Initialize the ExecutionEngine. Global constructors specified by "
+ "`llvm.mlir.global_ctors` will be run. One common scenario is that "
+ "kernel binary compiled from `gpu.module` gets loaded during "
+ "initialization. Make sure all symbols are resolvable before "
+ "initialization by calling `register_runtime` or including "
+ "shared libraries.")
+ .def(
"dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) {
mlirExecutionEngineDumpToObjectFile(
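
The `initialize` binding added above simply forwards to `mlirExecutionEngineInitialize`; a minimal usage sketch from the Python side might look as follows (assuming the module has already been lowered to the LLVM dialect and that the Python-level method keeps the `initialize` name bound here):

    from mlir import ir
    from mlir.execution_engine import ExecutionEngine

    with ir.Context():
        # An already-lowered module; real code would run a conversion pipeline first.
        module = ir.Module.parse("module {}")
        engine = ExecutionEngine(module)
        # Runs llvm.mlir.global_ctors, e.g. loading a kernel binary produced from a
        # gpu.module, so any runtime symbols must be registered before this call.
        engine.initialize()
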
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 826a34a..71a051c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -10,15 +10,19 @@
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include <optional>
+#include <regex>
#include <string>
+#include <unordered_set>
#include <vector>
#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Regex.h"
namespace mlir {
namespace python {
@@ -114,6 +118,39 @@ public:
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
+ class TracebackLoc {
+ public:
+ bool locTracebacksEnabled();
+
+ void setLocTracebacksEnabled(bool value);
+
+ size_t locTracebackFramesLimit();
+
+ void setLocTracebackFramesLimit(size_t value);
+
+ void registerTracebackFileInclusion(const std::string &file);
+
+ void registerTracebackFileExclusion(const std::string &file);
+
+ bool isUserTracebackFilename(llvm::StringRef file);
+
+ static constexpr size_t kMaxFrames = 512;
+
+ private:
+ nanobind::ft_mutex mutex;
+ bool locTracebackEnabled_ = false;
+ size_t locTracebackFramesLimit_ = 10;
+ std::unordered_set<std::string> userTracebackIncludeFiles;
+ std::unordered_set<std::string> userTracebackExcludeFiles;
+ std::regex userTracebackIncludeRegex;
+ bool rebuildUserTracebackIncludeRegex = false;
+ std::regex userTracebackExcludeRegex;
+ bool rebuildUserTracebackExcludeRegex = false;
+ llvm::StringMap<bool> isUserTracebackFilenameCache;
+ };
+
+ TracebackLoc &getTracebackLoc() { return tracebackLoc; }
+
private:
static PyGlobals *instance;
@@ -134,6 +171,8 @@ private:
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
+
+ TracebackLoc tracebackLoc;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 50f2a4f..a6499c9 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -17,9 +17,9 @@
#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
@@ -64,7 +64,7 @@ static void pyListToVector(const nb::list &list,
}
template <typename PermutationTy>
-static bool isPermutation(std::vector<PermutationTy> permutation) {
+static bool isPermutation(const std::vector<PermutationTy> &permutation) {
llvm::SmallVector<bool, 8> seen(permutation.size(), false);
for (auto val : permutation) {
if (val < permutation.size()) {
@@ -366,7 +366,7 @@ nb::object PyAffineExpr::getCapsule() {
return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
}
-PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
+PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) {
MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
if (mlirAffineExprIsNull(rawAffineExpr))
throw nb::python_error();
@@ -424,7 +424,7 @@ nb::object PyAffineMap::getCapsule() {
return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
}
-PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
+PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) {
MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
if (mlirAffineMapIsNull(rawAffineMap))
throw nb::python_error();
@@ -500,7 +500,7 @@ nb::object PyIntegerSet::getCapsule() {
return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
}
-PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
+PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
if (mlirIntegerSetIsNull(rawIntegerSet))
throw nb::python_error();
@@ -708,7 +708,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_static("compress_unused_symbols",
- [](nb::list affineMaps, DefaultingPyMlirContext context) {
+ [](const nb::list &affineMaps,
+ DefaultingPyMlirContext context) {
SmallVector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
@@ -734,7 +735,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
+ [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs,
DefaultingPyMlirContext context) {
SmallVector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -869,7 +870,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
.def("__eq__", [](PyIntegerSet &self,
PyIntegerSet &other) { return self == other; })
- .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
+ .def("__eq__",
+ [](PyIntegerSet &self, const nb::object &other) { return false; })
.def("__str__",
[](PyIntegerSet &self) {
PyPrintAccumulator printAccum;
@@ -898,7 +900,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
+ [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs,
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
if (exprs.size() != eqFlags.size())
throw nb::value_error(
@@ -934,8 +936,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
nb::arg("context").none() = nb::none())
.def(
"get_replaced",
- [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
- intptr_t numResultDims, intptr_t numResultSymbols) {
+ [](PyIntegerSet &self, const nb::list &dimExprs,
+ const nb::list &symbolExprs, intptr_t numResultDims,
+ intptr_t numResultSymbols) {
if (static_cast<intptr_t>(dimExprs.size()) !=
mlirIntegerSetGetNumDims(self))
throw nb::value_error(
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index db84ee1..f2eafa7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -505,7 +505,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](nb::list attributes, DefaultingPyMlirContext context) {
+ [](const nb::list &attributes, DefaultingPyMlirContext context) {
SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(nb::len(attributes));
for (auto attribute : attributes) {
@@ -530,7 +530,7 @@ public:
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
- c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
+ c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
std::vector<MlirAttribute> attributes;
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
attributes.reserve(numOldElements + nb::len(extras));
@@ -708,7 +708,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string value, DefaultingPyMlirContext context) {
+ [](const std::string &value, DefaultingPyMlirContext context) {
MlirAttribute attr =
mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
return PyFlatSymbolRefAttribute(context->getRef(), attr);
@@ -736,8 +736,8 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
- DefaultingPyMlirContext context) {
+ [](const std::string &dialectNamespace, const nb_buffer &buffer,
+ PyType &type, DefaultingPyMlirContext context) {
const nb_buffer_info bufferInfo = buffer.request();
intptr_t bufferSize = bufferInfo.size;
MlirAttribute attr = mlirOpaqueAttrGet(
@@ -775,7 +775,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string value, DefaultingPyMlirContext context) {
+ [](const std::string &value, DefaultingPyMlirContext context) {
MlirAttribute attr =
mlirStringAttrGet(context->get(), toMlirStringRef(value));
return PyStringAttribute(context->getRef(), attr);
@@ -784,7 +784,7 @@ public:
"Gets a uniqued string attribute");
c.def_static(
"get",
- [](nb::bytes value, DefaultingPyMlirContext context) {
+ [](const nb::bytes &value, DefaultingPyMlirContext context) {
MlirAttribute attr =
mlirStringAttrGet(context->get(), toMlirStringRef(value));
return PyStringAttribute(context->getRef(), attr);
@@ -793,7 +793,7 @@ public:
"Gets a uniqued string attribute");
c.def_static(
"get_typed",
- [](PyType &type, std::string value) {
+ [](PyType &type, const std::string &value) {
MlirAttribute attr =
mlirStringAttrTypedGet(type, toMlirStringRef(value));
return PyStringAttribute(type.getContext(), attr);
@@ -826,7 +826,7 @@ public:
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseElementsAttribute
- getFromList(nb::list attributes, std::optional<PyType> explicitType,
+ getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
DefaultingPyMlirContext contextWrapper) {
const size_t numAttributes = nb::len(attributes);
if (numAttributes == 0)
@@ -878,8 +878,8 @@ public:
}
static PyDenseElementsAttribute
- getFromBuffer(nb_buffer array, bool signless,
- std::optional<PyType> explicitType,
+ getFromBuffer(const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
std::optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
// Request a contiguous view. In exotic cases, this will cause a copy.
@@ -894,8 +894,8 @@ public:
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
MlirContext context = contextWrapper->get();
- MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
- explicitShape, context);
+ MlirAttribute attr = getAttributeFromBuffer(
+ view, signless, explicitType, std::move(explicitShape), context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
@@ -1092,16 +1092,16 @@ private:
"when the type is not a shaped type.");
}
return *bulkLoadElementType;
- } else {
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- return mlirRankedTensorTypeGet(shape.size(), shape.data(),
- *bulkLoadElementType, encodingAttr);
}
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
}
static MlirAttribute getAttributeFromBuffer(
Py_buffer &view, bool signless, std::optional<PyType> explicitType,
- std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context) {
// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes exotics types which do not have a direct
@@ -1125,7 +1125,7 @@ private:
bulkLoadElementType = mlirF16TypeGet(context);
} else if (format == "?") {
// i1
- // The i1 type needs to be bit-packed, so we will handle it seperately
+ // The i1 type needs to be bit-packed, so we will handle it separately
return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
context);
} else if (isSignedIntegerFormat(format)) {
@@ -1205,8 +1205,8 @@ private:
packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
- MlirType bitpackedType =
- getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+ MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+ std::move(explicitShape), view);
assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
// Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
// packedBooleans, hence the MlirAttribute will remain valid even when
@@ -1443,9 +1443,9 @@ public:
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseResourceElementsAttribute
- getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
- std::optional<size_t> alignment, bool isMutable,
- DefaultingPyMlirContext contextWrapper) {
+ getFromBuffer(const nb_buffer &buffer, const std::string &name,
+ const PyType &type, std::optional<size_t> alignment,
+ bool isMutable, DefaultingPyMlirContext contextWrapper) {
if (!mlirTypeIsAShaped(type)) {
throw std::invalid_argument(
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
@@ -1534,7 +1534,7 @@ public:
c.def("__len__", &PyDictAttribute::dunderLen);
c.def_static(
"get",
- [](nb::dict attributes, DefaultingPyMlirContext context) {
+ [](const nb::dict &attributes, DefaultingPyMlirContext context) {
SmallVector<MlirNamedAttribute> mlirNamedAttributes;
mlirNamedAttributes.reserve(attributes.size());
for (std::pair<nb::handle, nb::handle> it : attributes) {
@@ -1618,7 +1618,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType value, DefaultingPyMlirContext context) {
+ [](const PyType &value, DefaultingPyMlirContext context) {
MlirAttribute attr = mlirTypeAttrGet(value.get());
return PyTypeAttribute(context->getRef(), attr);
},
@@ -1663,7 +1663,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](int64_t offset, const std::vector<int64_t> strides,
+ [](int64_t offset, const std::vector<int64_t> &strides,
DefaultingPyMlirContext ctx) {
MlirAttribute attr = mlirStridedLayoutAttrGet(
ctx->get(), offset, strides.size(), strides.data());
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5feed95..2df2a73 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -20,11 +20,8 @@
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
#include <optional>
-#include <system_error>
-#include <utility>
namespace nb = nanobind;
using namespace nb::literals;
@@ -70,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";
+static const char kModuleCAPICreate[] =
+ R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
+Note this returns a new object BUT _clear_mlir_module(module) must be called to
+prevent double-frees (of the underlying mlir::Module).
+)";
+
static const char kOperationCreateDocstring[] =
R"(Creates a new operation.
@@ -199,7 +202,7 @@ operations.
/// Helper for creating an @classmethod.
template <class Func, typename... Args>
-nb::object classmethod(Func f, Args... args) {
+static nb::object classmethod(Func f, Args... args) {
nb::object cf = nb::cpp_function(f, args...);
return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
}
@@ -705,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
return getLiveContexts().size();
}
-size_t PyMlirContext::getLiveOperationCount() {
- nb::ft_lock_guard lock(liveOperationsMutex);
- return liveOperations.size();
-}
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
- std::vector<PyOperation *> liveObjects;
- nb::ft_lock_guard lock(liveOperationsMutex);
- for (auto &entry : liveOperations)
- liveObjects.push_back(entry.second.second);
- return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
-
- LiveOperationMap operations;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- std::swap(operations, liveOperations);
- }
- for (auto &op : operations)
- op.second.second->setInvalid();
- size_t numInvalidated = operations.size();
- return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
- PyOperation *py_op;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- auto it = liveOperations.find(op.ptr);
- if (it == liveOperations.end()) {
- return;
- }
- py_op = it->second.second;
- liveOperations.erase(it);
- }
- py_op->setInvalid();
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
- typedef struct {
- PyOperation &rootOp;
- bool rootSeen;
- } callBackData;
- callBackData data{op.getOperation(), false};
- // Mark all ops below the op that the passmanager will be rooted
- // at (but not op itself - note the preorder) as invalid.
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- callBackData *data = static_cast<callBackData *>(userData);
- if (LLVM_LIKELY(data->rootSeen))
- data->rootOp.getOperation().getContext()->clearOperation(op);
- else
- data->rootSeen = true;
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
- PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
- clearOperationsInside(opRef->getOperation());
-}
-
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
- contextRef->clearOperation(op);
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- &op.getOperation().getContext(), MlirWalkPreOrder);
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
nb::object PyMlirContext::contextEnter(nb::object context) {
return PyThreadContextEntry::pushContext(context);
}
@@ -1154,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
-PyModule::~PyModule() {
- nb::gil_scoped_acquire acquire;
- auto &liveModules = getContext()->liveModules;
- assert(liveModules.count(module.ptr) == 1 &&
- "destroying module not in live map");
- liveModules.erase(module.ptr);
- mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- nb::gil_scoped_acquire acquire;
- auto &liveModules = contextRef->liveModules;
- auto it = liveModules.find(module.ptr);
- if (it == liveModules.end()) {
- // Create.
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is automatic_reference,
- // which does not take ownership (delete will not be called).
- // Just be explicit.
- nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
- unownedModule->handle = pyRef;
- liveModules[module.ptr] =
- std::make_pair(unownedModule->handle, unownedModule);
- return PyModuleRef(unownedModule, std::move(pyRef));
- }
- // Use existing.
- PyModule *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyModuleRef(existing, std::move(pyRef));
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is `automatic_reference`,
+ // which means "does not take ownership, does not call delete/dtor".
+ // We use `take_ownership`, which means "Python will call the C++ destructor
+ // and delete operator when the Python wrapper is garbage collected", because
+ // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
+ // etc).
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ return PyModuleRef(unownedModule, std::move(pyRef));
}
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1210,15 +1120,11 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
-
- // Otherwise, invalidate the operation and remove it from live map when it is
- // attached.
- if (isAttached()) {
- getContext()->clearOperation(*this);
- } else {
- // And destroy it when it is detached, i.e. owned by Python, in which case
- // all nested operations must be invalidated at removed from the live map as
- // well.
+ // Otherwise, invalidate the operation when it is attached.
+ if (isAttached())
+ setInvalid();
+ else {
+ // And destroy it when it is detached, i.e. owned by Python.
erase();
}
}
@@ -1255,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
- auto &liveOperations = contextRef->liveOperations;
- auto it = liveOperations.find(operation.ptr);
- if (it == liveOperations.end()) {
- // Create.
- PyOperationRef result = createInstance(std::move(contextRef), operation,
- std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(result.getObject(), result.get());
- return result;
- }
- // Use existing.
- PyOperation *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyOperationRef(existing, std::move(pyRef));
+ return createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
- auto &liveOperations = contextRef->liveOperations;
- assert(liveOperations.count(operation.ptr) == 0 &&
- "cannot create detached operation that already exists");
- (void)liveOperations;
PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(created.getObject(), created.get());
created->attached = false;
return created;
}
@@ -1523,7 +1409,7 @@ nb::object PyOperation::create(std::string_view name,
llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- int regions, DefaultingPyLocation location,
+ int regions, PyLocation &location,
const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1513,7 @@ nb::object PyOperation::create(std::string_view name,
if (!operation.ptr)
throw nb::value_error("Operation creation failed");
PyOperationRef created =
- PyOperation::createDetached(location->getContext(), operation);
+ PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);
return created.getObject();
@@ -1655,7 +1541,7 @@ nb::object PyOperation::createOpView() {
void PyOperation::erase() {
checkValid();
- getContext()->clearOperationAndInside(*this);
+ setInvalid();
mlirOperationDestroy(operation);
}
@@ -1937,9 +1823,9 @@ nb::object PyOpView::buildGeneric(
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, PyLocation &location,
const nb::object &maybeIp) {
- PyMlirContextRef context = location->getContext();
+ PyMlirContextRef context = location.getContext();
// Class level operation construction metadata.
// Operand and result segment specs are either none, which does no
@@ -2108,7 +1994,7 @@ nb::object PyOpView::buildGeneric(
// Delegate to create.
return PyOperation::create(name,
/*results=*/std::move(resultTypes),
- /*operands=*/std::move(operands),
+ /*operands=*/operands,
/*attributes=*/std::move(attributes),
/*successors=*/std::move(successors),
/*regions=*/*regions, location, maybeIp,
@@ -2789,6 +2675,156 @@ private:
PyOperationRef operation;
};
+// see
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using directly NULL
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
+ (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+ Py_XINCREF(obj);
+ return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+ Py_INCREF(obj);
+ return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame()
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+ assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+ return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack()
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+ assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+ return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode()
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+ assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+ assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+ return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+MlirLocation tracebackToLocation(MlirContext ctx) {
+ size_t framesLimit =
+ PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+  // Use a thread_local buffer here to avoid a large per-call stack allocation.
+ thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+ frames;
+ size_t count = 0;
+
+ nb::gil_scoped_acquire acquire;
+ PyThreadState *tstate = PyThreadState_GET();
+ PyFrameObject *next;
+ PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
+ // In the increment expression:
+  // 1. get the previous (calling) frame;
+  // 2. decrement the ref count on the current frame (so that it, and any
+  //    objects in its closure, can be gc'd);
+ // 3. set current = next.
+ for (; pyFrame != nullptr && count < framesLimit;
+ next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
+ PyCodeObject *code = PyFrame_GetCode(pyFrame);
+ auto fileNameStr =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+ llvm::StringRef fileName(fileNameStr);
+ if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+ continue;
+
+ // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+ std::string name =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+ llvm::StringRef funcName(name);
+ int startLine = PyFrame_GetLineNumber(pyFrame);
+ MlirLocation loc =
+ mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+ std::string name =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+ llvm::StringRef funcName(name);
+ int startLine, startCol, endLine, endCol;
+ int lasti = PyFrame_GetLasti(pyFrame);
+ if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+ &endCol)) {
+ throw nb::python_error();
+ }
+ MlirLocation loc = mlirLocationFileLineColRangeGet(
+ ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+
+ frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+ ++count;
+ }
+  // When the loop exits, the current frame (if non-null) would be leaked
+  // without this decref.
+ Py_XDECREF(pyFrame);
+
+ if (count == 0)
+ return mlirLocationUnknownGet(ctx);
+
+ MlirLocation callee = frames[0];
+ assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+ if (count == 1)
+ return callee;
+
+ MlirLocation caller = frames[count - 1];
+ assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+ for (int i = count - 2; i >= 1; i--)
+ caller = mlirLocationCallSiteGet(frames[i], caller);
+
+ return mlirLocationCallSiteGet(callee, caller);
+}
+
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+ if (location.has_value())
+ return location.value();
+ if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
+ return DefaultingPyLocation::resolve();
+
+ PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
+ MlirLocation mlirLoc = tracebackToLocation(ctx.get());
+ PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
+ return {ref, mlirLoc};
+}
+
} // namespace
//------------------------------------------------------------------------------
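
The `tracebackToLocation` helper above collects user frames innermost-first and folds them into a nested callsite location: for frames f0 (innermost) through fN (outermost) it produces callsite(f0, callsite(f1, ... callsite(fN-1, fN))), each frame being a NameLoc wrapping a file/line/column location. A hedged sketch of the same shape built through the existing Python Location API (file and function names here are invented for illustration):

    from mlir import ir

    with ir.Context():
        f0 = ir.Location.name("helper", ir.Location.file("lib.py", 7, 1))     # innermost frame
        f1 = ir.Location.name("build", ir.Location.file("app.py", 21, 1))
        f2 = ir.Location.name("<module>", ir.Location.file("app.py", 40, 1))  # outermost frame
        # Location.callsite folds the frame list the same way the C++ loop does.
        loc = ir.Location.callsite(f0, [f1, f2])
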
@@ -2876,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
- .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
- .def("_get_live_operation_objects",
- &PyMlirContext::getLiveOperationObjects)
- .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
- .def("_clear_live_operations_inside",
- nb::overload_cast<MlirOperation>(
- &PyMlirContext::clearOperationsInside))
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3052,10 +3080,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
.def_prop_ro_static(
"current",
- [](nb::object & /*class*/) {
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
auto *loc = PyThreadContextEntry::getDefaultLocation();
if (!loc)
- throw nb::value_error("No current Location");
+ return std::nullopt;
return loc;
},
"Gets the Location bound to the current thread or raises ValueError")
@@ -3201,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ kModuleCAPICreate)
+ .def("_clear_mlir_module", &PyModule::clearMlirModule)
.def_static(
"parse",
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3240,8 +3270,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
- [](DefaultingPyLocation loc) {
- MlirModule module = mlirModuleCreateEmpty(loc);
+ [](const std::optional<PyLocation> &loc) {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3280,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
- kOperationStrDunderDocstring);
+ kOperationStrDunderDocstring)
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a);
//----------------------------------------------------------------------------
// Mapping of Operation.
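
Taken together with the removal of the per-context module map, the Module changes above mean every wrapper owns its MlirModule; when a second wrapper is created through the capsule API, `_clear_mlir_module` must be called on the source wrapper to avoid a double free, and `__eq__` now compares the underlying modules rather than wrapper identity. A sketch, assuming the usual `_CAPIPtr`/`_CAPICreate` spellings of the capsule attributes:

    from mlir import ir

    with ir.Context():
        m1 = ir.Module.parse("module {}")
        # Round-trip through the capsule: this yields a second wrapper around the
        # same underlying mlir::Module.
        m2 = ir.Module._CAPICreate(m1._CAPIPtr)
        assert m1 == m2          # value equality via mlirModuleEqual
        # Exactly one wrapper may keep ownership; clear the other one.
        m1._clear_mlir_module()
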
@@ -3292,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
- return &self.getOperation() == &other.getOperation();
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
})
.def("__eq__",
[](PyOperationBase &self, nb::object other) { return false; })
@@ -3442,6 +3480,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
nb::arg("walk_order") = MlirWalkPostOrder);
@@ -3454,8 +3500,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const nb::object &maybeIp,
- bool inferType) {
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp, bool inferType) {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
if (operands) {
@@ -3467,8 +3513,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
}
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, location, maybeIp,
+ successors, regions, pyLoc, maybeIp,
inferType);
},
nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3498,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
- "Returns the list of Operation successors.");
+ "Returns the list of Operation successors.")
+ .def("_set_invalid", &PyOperation::setInvalid,
+ "Invalidate the operation.");
auto opViewClass =
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
@@ -3512,12 +3561,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, location, maybeIp));
+ attributes, successors, regions, pyLoc, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3540,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
- "Returns the list of Operation successors.");
+ "Returns the list of Operation successors.")
+ .def(
+ "_set_invalid",
+ [](PyOpView &self) { self.getOperation().setInvalid(); },
+ "Invalidate the operation.");
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
@@ -3551,17 +3606,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, std::optional<PyLocation> location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
- regions, location, maybeIp);
+ regions, pyLoc, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index e600f1b..0de2f17 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -13,9 +13,9 @@
#include "Globals.h"
#include "NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
namespace nb = nanobind;
using namespace mlir;
@@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Not found and loading did not yield a registration.
return std::nullopt;
}
+
+bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
+ nanobind::ft_lock_guard lock(mutex);
+ return locTracebackEnabled_;
+}
+
+void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
+ nanobind::ft_lock_guard lock(mutex);
+ locTracebackEnabled_ = value;
+}
+
+size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
+ nanobind::ft_lock_guard lock(mutex);
+ return locTracebackFramesLimit_;
+}
+
+void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
+ nanobind::ft_lock_guard lock(mutex);
+ locTracebackFramesLimit_ = std::min(value, kMaxFrames);
+}
+
+void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
+ const std::string &file) {
+ nanobind::ft_lock_guard lock(mutex);
+ auto reg = "^" + llvm::Regex::escape(file);
+ if (userTracebackIncludeFiles.insert(reg).second)
+ rebuildUserTracebackIncludeRegex = true;
+ if (userTracebackExcludeFiles.count(reg)) {
+ if (userTracebackExcludeFiles.erase(reg))
+ rebuildUserTracebackExcludeRegex = true;
+ }
+}
+
+void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
+ const std::string &file) {
+ nanobind::ft_lock_guard lock(mutex);
+ auto reg = "^" + llvm::Regex::escape(file);
+ if (userTracebackExcludeFiles.insert(reg).second)
+ rebuildUserTracebackExcludeRegex = true;
+ if (userTracebackIncludeFiles.count(reg)) {
+ if (userTracebackIncludeFiles.erase(reg))
+ rebuildUserTracebackIncludeRegex = true;
+ }
+}
+
+bool PyGlobals::TracebackLoc::isUserTracebackFilename(
+ const llvm::StringRef file) {
+ nanobind::ft_lock_guard lock(mutex);
+ if (rebuildUserTracebackIncludeRegex) {
+ userTracebackIncludeRegex.assign(
+ llvm::join(userTracebackIncludeFiles, "|"));
+ rebuildUserTracebackIncludeRegex = false;
+ isUserTracebackFilenameCache.clear();
+ }
+ if (rebuildUserTracebackExcludeRegex) {
+ userTracebackExcludeRegex.assign(
+ llvm::join(userTracebackExcludeFiles, "|"));
+ rebuildUserTracebackExcludeRegex = false;
+ isUserTracebackFilenameCache.clear();
+ }
+ if (!isUserTracebackFilenameCache.contains(file)) {
+ std::string fileStr = file.str();
+ bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
+ bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
+ isUserTracebackFilenameCache[file] = include || !exclude;
+ }
+ return isUserTracebackFilenameCache[file];
+}
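
The include/exclude decision above caches, per filename, the value `include || !exclude`: a file qualifies if it matches an include pattern, or if it matches no exclude pattern; only files that match an exclude pattern and no include pattern are filtered out. A small Python rendering of the same rule (hypothetical helper, for illustration only):

    import re

    def is_user_traceback_filename(filename, include_patterns, exclude_patterns):
        # Patterns are anchored regexes built from registered file prefixes,
        # mirroring registerTracebackFileInclusion/Exclusion above.
        include = any(re.search(p, filename) for p in include_patterns)
        exclude = any(re.search(p, filename) for p in exclude_patterns)
        return include or not exclude

    assert is_user_traceback_filename("/home/me/app.py", [], [r"^/usr/lib/python"])
    assert not is_user_traceback_filename("/usr/lib/python3.12/abc.py", [], [r"^/usr/lib/python"])
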
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9c22dea..0cc0459 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -192,16 +192,6 @@ public:
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;
- /// For the case of a python __init__ (nanobind::init) method, pybind11 is
- /// quite strict about needing to return a pointer that is not yet associated
- /// to an nanobind::object. Since the forContext() method acts like a pool,
- /// possibly returning a recycled context, it does not satisfy this need. The
- /// usual way in python to accomplish such a thing is to override __new__, but
- /// that is also not supported by pybind11. Instead, we use this entry
- /// point which always constructs a fresh context (which cannot alias an
- /// existing one because it is fresh).
- static PyMlirContext *createNewContextForInit();
-
/// Returns a context reference for the singleton PyMlirContext wrapper for
/// the given context.
static PyMlirContextRef forContext(MlirContext context);
@@ -228,40 +218,6 @@ public:
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
- /// Get a list of Python objects which are still in the live context map.
- std::vector<PyOperation *> getLiveOperationObjects();
-
- /// Gets the count of live operations associated with this context.
- /// Used for testing.
- size_t getLiveOperationCount();
-
- /// Clears the live operations map, returning the number of entries which were
- /// invalidated. To be used as a safety mechanism so that API end-users can't
- /// corrupt by holding references they shouldn't have accessed in the first
- /// place.
- size_t clearLiveOperations();
-
- /// Removes an operation from the live operations map and sets it invalid.
- /// This is useful for when some non-bindings code destroys the operation and
- /// the bindings need to made aware. For example, in the case when pass
- /// manager is run.
- ///
- /// Note that this does *NOT* clear the nested operations.
- void clearOperation(MlirOperation op);
-
- /// Clears all operations nested inside the given op using
- /// `clearOperation(MlirOperation)`.
- void clearOperationsInside(PyOperationBase &op);
- void clearOperationsInside(MlirOperation op);
-
- /// Clears the operaiton _and_ all operations inside using
- /// `clearOperation(MlirOperation)`.
- void clearOperationAndInside(PyOperationBase &op);
-
- /// Gets the count of live modules associated with this context.
- /// Used for testing.
- size_t getLiveModuleCount();
-
/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
@@ -288,25 +244,6 @@ private:
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
- // Interns all live modules associated with this context. Modules tracked
- // in this map are valid. When a module is invalidated, it is removed
- // from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveModuleMap =
- llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
- LiveModuleMap liveModules;
-
- // Interns all live operations associated with this context. Operations
- // tracked in this map are valid. When an operation is invalidated, it is
- // removed from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveOperationMap =
- llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
- nanobind::ft_mutex liveOperationsMutex;
-
- // Guarded by liveOperationsMutex in free-threading mode.
- LiveOperationMap liveOperations;
-
bool emitErrorDiagnostics = false;
MlirContext context;
@@ -558,8 +495,8 @@ class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
class PyModule : public BaseContextObject {
public:
- /// Returns a PyModule reference for the given MlirModule. This may return
- /// a pre-existing or new object.
+ /// Returns a PyModule reference for the given MlirModule. This always returns
+ /// a new object.
static PyModuleRef forModule(MlirModule module);
PyModule(PyModule &) = delete;
PyModule(PyMlirContext &&) = delete;
@@ -580,11 +517,12 @@ public:
nanobind::object getCapsule();
/// Creates a PyModule from the MlirModule wrapped by a capsule.
- /// Note that PyModule instances are uniqued, so the returned object
- /// may be a pre-existing object. Ownership of the underlying MlirModule
- /// is taken by calling this function.
+  /// Note that this returns a new object, but clearMlirModule() must be
+  /// called to prevent double-frees (of the underlying mlir::Module).
static nanobind::object createFromCapsule(nanobind::object capsule);
+ void clearMlirModule() { module = {nullptr}; }
+
private:
PyModule(PyMlirContextRef contextRef, MlirModule module);
MlirModule module;
@@ -722,8 +660,7 @@ public:
llvm::ArrayRef<MlirValue> operands,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const nanobind::object &ip,
- bool inferType);
+ PyLocation &location, const nanobind::object &ip, bool inferType);
/// Creates an OpView suitable for this operation.
nanobind::object createOpView();
@@ -781,7 +718,7 @@ public:
nanobind::list operandList,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, PyLocation &location,
const nanobind::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
@@ -1227,7 +1164,7 @@ public:
/// Note that PyAffineExpr instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
/// is taken by calling this function.
- static PyAffineExpr createFromCapsule(nanobind::object capsule);
+ static PyAffineExpr createFromCapsule(const nanobind::object &capsule);
PyAffineExpr add(const PyAffineExpr &other) const;
PyAffineExpr mul(const PyAffineExpr &other) const;
@@ -1254,7 +1191,7 @@ public:
/// Note that PyAffineMap instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineMap
/// is taken by calling this function.
- static PyAffineMap createFromCapsule(nanobind::object capsule);
+ static PyAffineMap createFromCapsule(const nanobind::object &capsule);
private:
MlirAffineMap affineMap;
@@ -1274,7 +1211,7 @@ public:
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
/// Note that PyIntegerSet instances may be uniqued, so the returned object
/// may be a pre-existing object. Integer sets are owned by the context.
- static PyIntegerSet createFromCapsule(nanobind::object capsule);
+ static PyIntegerSet createFromCapsule(const nanobind::object &capsule);
private:
MlirIntegerSet integerSet;
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index b11e3f7..a9b1259 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -963,7 +963,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string dialectNamespace, std::string typeData,
+ [](const std::string &dialectNamespace, const std::string &typeData,
DefaultingPyMlirContext context) {
MlirType type = mlirOpaqueTypeGet(context->get(),
toMlirStringRef(dialectNamespace),
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431..278847e 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-
#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
@@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) {
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, nb::kw_only(),
"replace"_a = false,
- "Testing hook for directly registering an operation");
+ "Testing hook for directly registering an operation")
+ .def("loc_tracebacks_enabled",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebacksEnabled();
+ })
+ .def("set_loc_tracebacks_enabled",
+ [](PyGlobals &self, bool enabled) {
+ self.getTracebackLoc().setLocTracebacksEnabled(enabled);
+ })
+ .def("set_loc_tracebacks_frame_limit",
+ [](PyGlobals &self, int n) {
+ self.getTracebackLoc().setLocTracebackFramesLimit(n);
+ })
+ .def("register_traceback_file_inclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileInclusion(filename);
+ })
+ .def("register_traceback_file_exclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileExclusion(filename);
+ });
// Aside from making the globals accessible to python, having python manage
// it is necessary to make sure it is destroyed (and releases its python
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 20017e2..88e28dc 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -39,7 +39,7 @@ public:
return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
}
- static nb::object createFromCapsule(nb::object capsule) {
+ static nb::object createFromCapsule(const nb::object &capsule) {
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
if (mlirPassManagerIsNull(rawPm))
throw nb::python_error();
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"ValueError if the pipeline can't be parsed.")
.def(
"run",
- [](PyPassManager &passManager, PyOperationBase &op,
- bool invalidateOps) {
- if (invalidateOps) {
- op.getOperation().getContext()->clearOperationsInside(op);
- }
+ [](PyPassManager &passManager, PyOperationBase &op) {
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
- "operation"_a, "invalidate_ops"_a = true,
+ "operation"_a,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp
index 3ba42be..3edcb09 100644
--- a/mlir/lib/Bindings/Python/RegisterEverything.cpp
+++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/RegisterEverything.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
NB_MODULE(_mlirRegisterEverything, m) {
m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration";
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f9b0fed..920bca8 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
// root. This is awkward, but we don't have access to PyMlirContext
// object here otherwise.
nb::object obj = nb::cast(payloadRoot);
- obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot, transformRoot, transformModule, options.options);
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 6ebeac5..eacb936 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -150,8 +151,7 @@ public:
/// Backpatch a byte in the result buffer at the given offset.
void patchByte(uint64_t offset, uint8_t value, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs() << "patchByte(" << offset << ',' << uint64_t(value)
- << ")\t" << desc << '\n');
+ LDBG() << "patchByte(" << offset << ',' << uint64_t(value) << ")\t" << desc;
assert(offset < size() && offset >= prevResultSize &&
"cannot patch previously emitted data");
currentResult[offset - prevResultSize] = value;
@@ -160,8 +160,7 @@ public:
/// Emit the provided blob of data, which is owned by the caller and is
/// guaranteed to not die before the end of the bytecode process.
void emitOwnedBlob(ArrayRef<uint8_t> data, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs()
- << "emitOwnedBlob(" << data.size() << "b)\t" << desc << '\n');
+ LDBG() << "emitOwnedBlob(" << data.size() << "b)\t" << desc;
// Push the current buffer before adding the provided data.
appendResult(std::move(currentResult));
appendOwnedResult(data);
@@ -209,15 +208,13 @@ public:
/// Emit a single byte.
template <typename T>
void emitByte(T byte, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs()
- << "emitByte(" << uint64_t(byte) << ")\t" << desc << '\n');
+ LDBG() << "emitByte(" << uint64_t(byte) << ")\t" << desc;
currentResult.push_back(static_cast<uint8_t>(byte));
}
/// Emit a range of bytes.
void emitBytes(ArrayRef<uint8_t> bytes, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs()
- << "emitBytes(" << bytes.size() << "b)\t" << desc << '\n');
+ LDBG() << "emitBytes(" << bytes.size() << "b)\t" << desc;
llvm::append_range(currentResult, bytes);
}
@@ -229,7 +226,7 @@ public:
/// additional bytes, provide the value of the integer encoded in
/// little-endian order.
void emitVarInt(uint64_t value, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs() << "emitVarInt(" << value << ")\t" << desc << '\n');
+ LDBG() << "emitVarInt(" << value << ")\t" << desc;
// In the most common case, the value can be represented in a single byte.
// Given how hot this case is, explicitly handle that here.
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 306cebd..2dbb993 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
return wrap(jitOrError->release());
}
+extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
+ unwrap(jit)->initialize();
+}
+
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
delete (unwrap(jit));
}
@@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
void *sym) {
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
- symbolMap[interner(unwrap(name))] =
- { llvm::orc::ExecutorAddr::fromPtr(sym),
- llvm::JITSymbolFlags::Exported };
+ symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
+ llvm::JITSymbolFlags::Exported};
return symbolMap;
});
}
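
A minimal usage sketch, not part of this patch, for the new explicit initialization entry point; the exact creation arguments shown here (shared library paths, object dump flag) are assumptions about the pre-existing C API rather than something this change introduces:

    MlirExecutionEngine jit = mlirExecutionEngineCreate(
        module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
        /*enableObjectDump=*/false);
    /* Explicitly initialize the engine (the entry point added above). */
    mlirExecutionEngineInitialize(jit);
    /* ... look up and invoke JIT'ed functions ... */
    mlirExecutionEngineDestroy(jit);
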
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 9d8554a..f5f4ed3 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -465,10 +465,6 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
-MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
- return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
-}
-
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8491553..c7069f0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
}
+bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) {
+ return unwrap(lhs) == unwrap(rhs);
+}
+
//===----------------------------------------------------------------------===//
// Operation state API.
//===----------------------------------------------------------------------===//
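
A small illustrative sketch of what the new equality helper enables; it assumes only the mlirModuleGetOperation/mlirModuleFromOperation round-trip already present in the C API:

    MlirOperation moduleOp = mlirModuleGetOperation(module);
    MlirModule roundTripped = mlirModuleFromOperation(moduleOp);
    /* Distinct handles, same underlying ModuleOp, so equality holds. */
    bool same = mlirModuleEqual(module, roundTripped);
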
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 191b5ab6..91ed05f 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Query)
add_subdirectory(Reducer)
+add_subdirectory(Remark)
add_subdirectory(Rewrite)
add_subdirectory(Support)
add_subdirectory(TableGen)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 64720bf..203790e 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -1876,6 +1877,54 @@ struct AMDGPUSwizzleBitModeLowering
}
};
+struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx950)
+ return op->emitOpError("permlane_swap is only supported on gfx950+");
+
+ Location loc = op.getLoc();
+ Type i32 = rewriter.getI32Type();
+ Value src = adaptor.getSrc();
+ unsigned rowLength = op.getRowLength();
+ bool fi = op.getFetchInactive();
+ bool boundctrl = op.getBoundCtrl();
+
+ SmallVector<Value> decomposed =
+ LLVM::decomposeValue(rewriter, loc, src, i32);
+
+ SmallVector<Value> permuted;
+ for (Value v : decomposed) {
+ Value res;
+ Type i32pair = LLVM::LLVMStructType::getLiteral(
+ rewriter.getContext(), {v.getType(), v.getType()});
+
+ if (rowLength == 16)
+ res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
+ boundctrl);
+ else if (rowLength == 32)
+ res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
+ boundctrl);
+ else
+ llvm_unreachable("unsupported row length");
+
+ Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ permuted.emplace_back(vdstNew);
+ }
+
+ Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1944,6 +1993,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering>(converter, chipset);
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
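
A condensed sketch of the decompose/compose idiom that AMDGPUPermlaneLowering above relies on; processPiece is a hypothetical stand-in for the per-i32 ROCDL permlane intrinsic calls:

    // Split a wide value into i32 pieces, transform each piece, reassemble.
    Type i32 = rewriter.getI32Type();
    SmallVector<Value> pieces = LLVM::decomposeValue(rewriter, loc, src, i32);
    for (Value &piece : pieces)
      piece = processPiece(rewriter, loc, piece); // hypothetical helper
    Value result = LLVM::composeValue(rewriter, loc, pieces, src.getType());
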
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 515fe5c..b68933d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -610,16 +610,19 @@ public:
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
- emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
- rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
- Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
+ emitc::ExpressionOp ternary =
+ emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
+ ValueRange({lhs, rhs, excessCheck, poison}),
+ /*do_not_inline=*/false);
+ Block &bodyBlock = ternary.createBody();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
- EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
- Value resultOrPoison =
- emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
- excessCheck, arithmeticResult, poison);
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
+ bodyBlock.getArgument(0), bodyBlock.getArgument(1));
+ Value resultOrPoison = emitc::ConditionalOp::create(
+ rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
+ arithmeticResult, bodyBlock.getArgument(3));
emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c..cb0c829 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -480,6 +490,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
}
//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// ops that can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+ arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // In case of a 1:1 conversion, the 1:1 pattern will match.
+ if (llvm::hasSingleElement(adaptor.getTrueValue()))
+ return rewriter.notifyMatchFailure(
+ op, "not a 1:N conversion, 1:1 pattern will match");
+ if (!op.getCondition().getType().isInteger(1))
+ return rewriter.notifyMatchFailure(op,
+ "non-i1 conditions are not supported");
+ SmallVector<Value> results;
+ for (auto [trueValue, falseValue] :
+ llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+ results.push_back(arith::SelectOp::create(
+ rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
+ SelectOpOneToNLowering,
ShLIOpLowering,
ShRSIOpLowering,
ShRUIOpLowering,
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e28d5122..c69ede9 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -333,7 +333,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
tileLoadOp.getBase(),
memrefIndices, maskOp1D,
- /*passthru=*/pad1DOp);
+ /*passthrough=*/pad1DOp);
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
auto insertSlice = arm_sme::InsertTileSliceOp::create(
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb82..71986f8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -50,6 +50,7 @@ add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
+add_subdirectory(PtrToLLVM)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
@@ -68,6 +69,7 @@ add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
@@ -75,3 +77,4 @@ add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
add_subdirectory(VectorToXeGPU)
add_subdirectory(XeVMToLLVM)
+add_subdirectory(XeGPUToXeVM)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 35ad99c..7a3a7fd 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,22 +56,30 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
private:
std::string funcName;
};
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+ using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+ Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+ Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+ rewriter.replaceOp(op, exp);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
- patterns.getContext(), "__ocml_carg_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
- patterns.getContext(), "__ocml_carg_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
- patterns.getContext(), "__ocml_conj_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
- patterns.getContext(), "__ocml_conj_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
patterns.getContext(), "__ocml_ccos_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
@@ -84,10 +92,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_clog_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
patterns.getContext(), "__ocml_clog_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
- patterns.getContext(), "__ocml_cpow_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
- patterns.getContext(), "__ocml_cpow_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
patterns.getContext(), "__ocml_csin_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
@@ -122,10 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
- complex::CosOp, complex::ExpOp, complex::LogOp,
- complex::PowOp, complex::SinOp, complex::SqrtOp,
- complex::TanOp, complex::TanhOp>();
+ target.addLegalOp<complex::MulOp>();
+ target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
+ complex::LogOp, complex::PowOp, complex::SinOp,
+ complex::SqrtOp, complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369..798d8b0 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const ValueRange &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(ValueRange(flattenedAdaptor)));
if (failed(convertedBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedAdaptor, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(attrs);
@@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(ValueRange(flattenedAdaptorTrue)));
if (failed(convertedTrueBlock))
return failure();
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(ValueRange(flattenedAdaptorFalse)));
if (failed(convertedFalseBlock))
return failure();
- DictionaryAttr attrs = op->getAttrDictionary();
+ DictionaryAttr attrs = op->getDiscardableAttrDictionary();
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
- adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
*convertedTrueBlock, *convertedFalseBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
- newOp->setAttrs(attrs);
+ newOp->setDiscardableAttrs(attrs);
return success();
}
};
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4..764ad2e 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/DebugLog.h"
#include <memory>
#define DEBUG_TYPE "convert-to-llvm"
@@ -31,7 +32,8 @@ namespace {
class ConvertToLLVMPassInterface {
public:
ConvertToLLVMPassInterface(MLIRContext *context,
- ArrayRef<std::string> filterDialects);
+ ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback = true);
virtual ~ConvertToLLVMPassInterface() = default;
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +62,9 @@ protected:
MLIRContext *context;
/// List of dialects names to use as filters.
ArrayRef<std::string> filterDialects;
+ /// An experimental flag to disallow pattern rollback. This is more efficient
+ /// but not supported by all lowering patterns.
+ bool allowPatternRollback;
};
/// This DialectExtension can be attached to the context, which will invoke the
@@ -75,13 +80,13 @@ public:
void apply(MLIRContext *context,
MutableArrayRef<Dialect *> dialects) const final {
- LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
+ LDBG() << "Convert to LLVM extension load";
for (Dialect *dialect : dialects) {
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
- LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
- << dialect->getNamespace() << "\n");
+ LDBG() << "Convert to LLVM found dialect interface for "
+ << dialect->getNamespace();
iface->loadDependentDialects(context);
}
}
@@ -128,7 +133,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
/// Apply the conversion driver.
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
- if (failed(applyPartialConversion(op, *target, *patterns)))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, *target, *patterns, config)))
return failure();
return success();
}
@@ -179,7 +186,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
patterns);
// Apply the conversion.
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
return failure();
return success();
}
@@ -206,9 +215,11 @@ public:
std::shared_ptr<ConvertToLLVMPassInterface> impl;
// Choose the pass implementation.
if (useDynamic)
- impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
else
- impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
if (failed(impl->initialize()))
return failure();
this->impl = impl;
@@ -228,8 +239,10 @@ public:
//===----------------------------------------------------------------------===//
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
- MLIRContext *context, ArrayRef<std::string> filterDialects)
- : context(context), filterDialects(filterDialects) {}
+ MLIRContext *context, ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback)
+ : context(context), filterDialects(filterDialects),
+ allowPatternRollback(allowPatternRollback) {}
void ConvertToLLVMPassInterface::getDependentDialects(
DialectRegistry &registry) {
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 67bb1c1..42c76ed 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
- LogicalResult matchAndRewriteImpl(CallOpType callOp,
- typename CallOpType::Adaptor adaptor,
+ LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
+ SmallVector<SmallVector<Type>> groupedResultTypes;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
+ int64_t numConvertedTypes = 0;
if (numResults != 0) {
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
- resultTypes, useBarePtrCallConv)))
+ resultTypes, useBarePtrCallConv, &groupedResultTypes,
+ &numConvertedTypes)))
return failure();
}
@@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
- SmallVector<Value, 4> results;
- if (numResults < 2) {
- // If < 2 results, packing did not do anything and we can just return.
- results.append(newOp.result_begin(), newOp.result_end());
- } else {
- // Otherwise, it had been converted to an operation producing a structure.
- // Extract individual results from the structure and return them as list.
- results.reserve(numResults);
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(LLVM::ExtractValueOp::create(
- rewriter, callOp.getLoc(), newOp->getResult(0), i));
+ // Helper function that extracts an individual result from the return value
+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+ // 2 or more results, the results are packed into a structure.
+ //
+  // The new call op may have 2 or more results because:
+  // a. The original call op has 2 or more results.
+  // b. An original op result was type-converted to more than 1 result.
+ auto getUnpackedResult = [&](unsigned i) -> Value {
+    assert(numConvertedTypes > 0 && "converted op has no results");
+ if (numConvertedTypes == 1) {
+ assert(i == 0 && "out of bounds: converted op has only one result");
+ return newOp->getResult(0);
}
+ // Results have been converted to a structure. Extract individual results
+ // from the structure.
+ return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+ newOp->getResult(0), i);
+ };
+
+ // Group the results into a vector of vectors, such that it is clear which
+ // original op result is replaced with which range of values. (In case of a
+ // 1:N conversion, there can be multiple replacements for a single result.)
+ SmallVector<SmallVector<Value>> results;
+ results.reserve(numResults);
+ unsigned counter = 0;
+ for (unsigned i = 0; i < numResults; ++i) {
+ SmallVector<Value> &group = results.emplace_back();
+ for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
+ group.push_back(getUnpackedResult(counter++));
}
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, promote memref results to
- // descriptors.
- assert(results.size() == resultTypes.size() &&
- "The number of arguments and types doesn't match");
- this->getTypeConverter()->promoteBarePtrsToDescriptors(
- rewriter, callOp.getLoc(), resultTypes, results);
- } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
- resultTypes, results,
- /*toDynamic=*/false))) {
- return failure();
+ // Special handling for MemRef types.
+ for (unsigned i = 0; i < numResults; ++i) {
+ Type origType = resultTypes[i];
+ auto memrefType = dyn_cast<MemRefType>(origType);
+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+ if (useBarePtrCallConv && memrefType) {
+ // For the bare-ptr calling convention, promote memref results to
+ // descriptors.
+ assert(results[i].size() == 1 && "expected one converted result");
+ results[i].front() = MemRefDescriptor::fromStaticShape(
+ rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+ results[i].front());
+ }
+ if (unrankedMemrefType) {
+ assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+ "bare-ptr calling convention");
+ assert(results[i].size() == 1 && "expected one converted result");
+ Value desc = this->copyUnrankedDescriptor(
+ rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+ /*toDynamic=*/false);
+ if (!desc)
+ return failure();
+ results[i].front() = desc;
+ }
}
- rewriter.replaceOp(callOp, results);
+ rewriter.replaceOpWithMultiple(callOp, results);
return success();
}
};
@@ -606,7 +638,7 @@ public:
symbolTables(symbolTables) {}
LogicalResult
- matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +668,7 @@ struct CallIndirectOpLowering
using Super::Super;
LogicalResult
- matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
@@ -679,41 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, extract the aligned pointer to
- // be returned from the memref descriptor.
- for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
- Type oldTy = std::get<0>(it).getType();
- Value newOperand = std::get<1>(it);
- if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
- cast<BaseMemRefType>(oldTy))) {
- MemRefDescriptor memrefDesc(newOperand);
- newOperand = memrefDesc.allocatedPtr(rewriter, loc);
- } else if (isa<UnrankedMemRefType>(oldTy)) {
+
+ for (auto [oldOperand, newOperands] :
+ llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+ Type oldTy = oldOperand.getType();
+ if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
+ if (useBarePtrCallConv &&
+ getTypeConverter()->canConvertToBarePtr(memRefType)) {
+ // For the bare-ptr calling convention, extract the aligned pointer to
+ // be returned from the memref descriptor.
+ MemRefDescriptor memrefDesc(newOperands.front());
+ updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
+ continue;
+ }
+ } else if (auto unrankedMemRefType =
+ dyn_cast<UnrankedMemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
+ if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
- updatedOperands.push_back(newOperand);
+ Value updatedDesc =
+ copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+ newOperands.front(), /*toDynamic=*/true);
+ if (!updatedDesc)
+ return failure();
+ updatedOperands.push_back(updatedDesc);
+ continue;
}
- } else {
- updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
- (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
- updatedOperands,
- /*toDynamic=*/true);
+
+ llvm::append_range(updatedOperands, newOperands);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
- if (numArguments <= 1) {
+ if (updatedOperands.size() <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3cfbd89..e516118 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -532,6 +532,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
vector::populateVectorTransferLoweringPatterns(patterns,
/*maxTransferRank=*/1);
+ // Transform N-D vector.from_elements to 1-D vector.from_elements before
+ // conversion.
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 317bfc2..93e370d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -369,6 +370,9 @@ struct LowerGpuOpsToNVVMOpsPass final
{
RewritePatternSet patterns(m.getContext());
populateGpuRewritePatterns(patterns);
+ // Transform N-D vector.from_elements to 1-D vector.from_elements before
+ // conversion.
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(m, std::move(patterns))))
return signalPassFailure();
}
@@ -394,7 +398,7 @@ struct LowerGpuOpsToNVVMOpsPass final
if (!allowedDialectsSet.empty() && !allowed)
continue;
- auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface) {
// Error out if dialect was explicily specified but doesn't implement
// conversion interface.
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index d22364e..8994905 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -79,17 +79,30 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
return canBeBare;
}
-static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
- const unsigned indexBitwidth) {
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
- Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
- ValueRange{minus1, zero});
- Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
- ValueRange{minus1, mbcntLo});
+ NamedAttribute noundef = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+ NamedAttribute lowRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 32)));
+ NamedAttribute highRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 64)));
+ Value mbcntLo = ROCDL::MbcntLoOp::create(
+ rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
+ /*res_attrs=*/
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
+ Value laneId = ROCDL::MbcntHiOp::create(
+ rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
return laneId;
}
+
static constexpr StringLiteral amdgcnDataLayout =
"e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
"-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
@@ -104,18 +117,16 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
LogicalResult
matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
- // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
- // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
-
- Type intTy = IntegerType::get(context, 32);
- Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
- Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
- Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
- ValueRange{minus1, zero});
- Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
- ValueRange{minus1, mbcntLo});
+ // convert to:
+ // %mlo = call noundef range(i32 0, 32)
+ // @llvm.amdgcn.mbcnt.lo(-1, 0)
+ // followed by:
+ // %lid = call noundef range(i32 0, 64)
+ // @llvm.amdgcn.mbcnt.hi(-1, %mlo)
+
+ Value laneId = getLaneId(rewriter, loc);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
@@ -160,6 +171,38 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
const amdgpu::Chipset chipset;
};
+static bool isSupportedReadLaneType(Type type) {
+  // read(first)lane also supports some vector types, but limit it to scalars
+  // for now.
+ return type.isInteger(16) || type.isInteger(32) || type.isInteger(64) ||
+ isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
+ LLVM::LLVMPointerType>(type);
+}
+
+struct GPUSubgroupBroadcastOpToROCDL
+ : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value src = adaptor.getSrc();
+ if (!isSupportedReadLaneType(src.getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported readlane type");
+
+ if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
+ rewriter.replaceOpWithNewOp<ROCDL::ReadlaneOp>(op, src.getType(), src,
+ adaptor.getLane());
+ } else { // first_active_lane or any_lane
+ // any_lane is lowered to readfirstlane too, to force value into scalar
+ // register.
+ rewriter.replaceOpWithNewOp<ROCDL::ReadfirstlaneOp>(op, src.getType(),
+ src);
+ }
+ return success();
+ }
+};
+
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
@@ -185,8 +228,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Location loc = op->getLoc();
Value initShflValue = adaptor.getValue();
- const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
- Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
+ Value srcLaneId = getLaneId(rewriter, loc);
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
@@ -317,7 +359,7 @@ struct LowerGpuOpsToROCDLOpsPass final
{
RewritePatternSet patterns(ctx);
populateGpuRewritePatterns(patterns);
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
(void)applyPatternsGreedily(m, std::move(patterns));
}
@@ -453,7 +495,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
// TODO: Add alignment for workgroup memory
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
- patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
+ patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
+ GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
populateMathToROCDLConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index fce7a3f..522e914 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
results.push_back(d.memRefDescPtr(builder, loc));
}
-void UnrankedMemRefDescriptor::computeSizes(
+Value UnrankedMemRefDescriptor::computeSize(
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
- SmallVectorImpl<Value> &sizes) {
- if (values.empty())
- return;
- assert(values.size() == addressSpaces.size() &&
- "must provide address space for each descriptor");
+ UnrankedMemRefDescriptor desc, unsigned addressSpace) {
// Cache the index type.
Type indexType = typeConverter.getIndexType();
@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
- sizes.reserve(sizes.size() + values.size());
- for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
- // Emit IR computing the memory necessary to store the descriptor. This
- // assumes the descriptor to be
- // { type*, type*, index, index[rank], index[rank] }
- // and densely packed, so the total size is
- // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
- // TODO: consider including the actual size (including eventual padding due
- // to data layout) into the unranked descriptor.
- Value pointerSize = createIndexAttrConstant(
- builder, loc, indexType,
- llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
- Value doublePointerSize =
- LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
-
- // (1 + 2 * rank) * sizeof(index)
- Value rank = desc.rank(builder, loc);
- Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
- Value doubleRankIncremented =
- LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
- Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
- doubleRankIncremented, indexSize);
-
- // Total allocation size.
- Value allocationSize = LLVM::AddOp::create(
- builder, loc, indexType, doublePointerSize, rankIndexSize);
- sizes.push_back(allocationSize);
- }
+ // Emit IR computing the memory necessary to store the descriptor. This
+ // assumes the descriptor to be
+ // { type*, type*, index, index[rank], index[rank] }
+ // and densely packed, so the total size is
+ // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+ // TODO: consider including the actual size (including eventual padding due
+ // to data layout) into the unranked descriptor.
+ Value pointerSize = createIndexAttrConstant(
+ builder, loc, indexType,
+ llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
+ Value doublePointerSize =
+ LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
+
+ // (1 + 2 * rank) * sizeof(index)
+ Value rank = desc.rank(builder, loc);
+ Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
+ Value doubleRankIncremented =
+ LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+ Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+ doubleRankIncremented, indexSize);
+
+ // Total allocation size.
+ Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
+ doublePointerSize, rankIndexSize);
+ return allocationSize;
}
Value UnrankedMemRefDescriptor::allocatedPtr(
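
Callers that previously relied on computeSizes() for a batch of descriptors can loop over the new per-descriptor computeSize(); a sketch, assuming the caller keeps the descriptors and address spaces in matching order:

    SmallVector<Value> sizes;
    for (auto [desc, addrSpace] : llvm::zip_equal(descriptors, addressSpaces))
      sizes.push_back(UnrankedMemRefDescriptor::computeSize(
          builder, loc, typeConverter, desc, addrSpace));
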
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 2568044..48a0319 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -216,34 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
return memRefDescriptor;
}
-LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
- OpBuilder &builder, Location loc, TypeRange origTypes,
- SmallVectorImpl<Value> &operands, bool toDynamic) const {
- assert(origTypes.size() == operands.size() &&
- "expected as may original types as operands");
-
- // Find operands of unranked memref type and store them.
- SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
- SmallVector<unsigned> unrankedAddressSpaces;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
- unrankedMemrefs.emplace_back(operands[i]);
- FailureOr<unsigned> addressSpace =
- getTypeConverter()->getMemRefAddressSpace(memRefType);
- if (failed(addressSpace))
- return failure();
- unrankedAddressSpaces.emplace_back(*addressSpace);
- }
- }
-
- if (unrankedMemrefs.empty())
- return success();
-
- // Compute allocation sizes.
- SmallVector<Value> sizes;
- UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
- unrankedMemrefs, unrankedAddressSpaces,
- sizes);
+Value ConvertToLLVMPattern::copyUnrankedDescriptor(
+ OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
+ Value operand, bool toDynamic) const {
+ // Convert memory space.
+ FailureOr<unsigned> addressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memRefType);
+ if (failed(addressSpace))
+ return {};
// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
@@ -254,52 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
if (failed(mallocFunc))
- return failure();
+ return {};
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
if (failed(freeFunc))
- return failure();
+ return {};
}
- unsigned unrankedMemrefPos = 0;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- Type type = origTypes[i];
- if (!isa<UnrankedMemRefType>(type))
- continue;
- Value allocationSize = sizes[unrankedMemrefPos++];
- UnrankedMemRefDescriptor desc(operands[i]);
-
- // Allocate memory, copy, and free the source if necessary.
- Value memory =
- toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
- allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
- Value source = desc.memRefDescPtr(builder, loc);
- LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
- if (!toDynamic)
- LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
-
- // Create a new descriptor. The same descriptor can be returned multiple
- // times, attempting to modify its pointer can lead to memory leaks
- // (allocated twice and overwritten) or double frees (the caller does not
- // know if the descriptor points to the same memory).
- Type descriptorType = getTypeConverter()->convertType(type);
- if (!descriptorType)
- return failure();
- auto updatedDesc =
- UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
- Value rank = desc.rank(builder, loc);
- updatedDesc.setRank(builder, loc, rank);
- updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ UnrankedMemRefDescriptor desc(operand);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ builder, loc, *getTypeConverter(), desc, *addressSpace);
+
+ // Allocate memory, copy, and free the source if necessary.
+ Value memory = toDynamic
+ ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
+ Value source = desc.memRefDescPtr(builder, loc);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
+ if (!toDynamic)
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
+
+ // Create a new descriptor. The same descriptor can be returned multiple
+ // times, attempting to modify its pointer can lead to memory leaks
+ // (allocated twice and overwritten) or double frees (the caller does not
+ // know if the descriptor points to the same memory).
+ Type descriptorType = getTypeConverter()->convertType(memRefType);
+ if (!descriptorType)
+ return {};
+ auto updatedDesc =
+ UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+ Value rank = desc.rank(builder, loc);
+ updatedDesc.setRank(builder, loc, rank);
+ updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ return updatedDesc;
+}
- operands[i] = updatedDesc;
+LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
+ OpBuilder &builder, Location loc, TypeRange origTypes,
+ SmallVectorImpl<Value> &operands, bool toDynamic) const {
+ assert(origTypes.size() == operands.size() &&
+ "expected as many original types as operands");
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
+ Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
+ operands[i], toDynamic);
+ if (!updatedDesc)
+ return failure();
+ operands[i] = updatedDesc;
+ }
}
-
return success();
}
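For orientation, a minimal self-contained C++ sketch of what the single-descriptor copy above amounts to; UnrankedDesc and copyDesc are invented names, and malloc stands in for the alloca used on the non-dynamic path:

  #include <cstddef>
  #include <cstdint>
  #include <cstdlib>
  #include <cstring>

  struct UnrankedDesc {
    std::int64_t rank;
    void *memRefDescPtr; // opaque, rank-dependent blob of sizes/strides
  };

  UnrankedDesc copyDesc(const UnrankedDesc &src, std::size_t blobSize,
                        bool toDynamic) {
    // Allocate fresh storage, copy the blob, and free the source only when
    // the copy is not handed off to the caller.
    void *memory = std::malloc(blobSize);
    std::memcpy(memory, src.memRefDescPtr, blobSize);
    if (!toDynamic)
      std::free(src.memRefDescPtr);
    // Always return a brand-new descriptor so callers never alias the source.
    return UnrankedDesc{src.rank, memory};
  }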
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf56..cb9dea1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
+
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
- Type type, bool useBarePtrCallConv) const {
- if (useBarePtrCallConv)
- if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
- return convertMemRefToBarePtr(memrefTy);
-
- return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+ Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+ if (useBarePtrCallConv) {
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+ Type converted = convertMemRefToBarePtr(memrefTy);
+ if (!converted)
+ return failure();
+ result.push_back(converted);
+ return success();
+ }
+ }
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
- ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const {
- assert(stdTypes.size() == values.size() &&
- "The number of types and values doesn't match");
- for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
- values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
- memrefTy, values[i]);
+ return convertType(type, result);
}
/// Convert a non-empty list of types of values produced by an operation into an
@@ -706,23 +699,35 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
- bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+ TypeRange types, bool useBarePtrCallConv,
+ SmallVector<SmallVector<Type>> *groupedTypes,
+ int64_t *numConvertedTypes) const {
assert(!types.empty() && "expected non-empty list of type");
+ assert((!groupedTypes || groupedTypes->empty()) &&
+ "expected groupedTypes to be empty");
useBarePtrCallConv |= options.useBarePtrCallConv;
- if (types.size() == 1)
- return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
SmallVector<Type> resultTypes;
resultTypes.reserve(types.size());
+ size_t sizeBefore = 0;
for (auto t : types) {
- auto converted = convertCallingConventionType(t, useBarePtrCallConv);
- if (!converted || !LLVM::isCompatibleType(converted))
+ if (failed(
+ convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
return {};
- resultTypes.push_back(converted);
+ if (groupedTypes) {
+ SmallVector<Type> &group = groupedTypes->emplace_back();
+ llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+ }
+ sizeBefore = resultTypes.size();
}
+ if (numConvertedTypes)
+ *numConvertedTypes = resultTypes.size();
+ if (resultTypes.size() == 1)
+ return resultTypes.front();
+ if (resultTypes.empty())
+ return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
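A hedged, self-contained illustration of the new groupedTypes bookkeeping: each original type contributes a contiguous slice of the flattened converted-type list, and that slice is recorded per original type (Type and convertOne below are placeholders, not MLIR API):

  #include <cstddef>
  #include <string>
  #include <vector>

  using Type = std::string;

  // Stand-in for convertCallingConventionType: one input type may expand to
  // several converted types (e.g. a memref expands to descriptor fields).
  std::vector<Type> convertOne(const Type &t) {
    if (t == "memref")
      return {"ptr", "ptr", "i64"};
    return {t};
  }

  std::vector<Type> packResults(const std::vector<Type> &types,
                                std::vector<std::vector<Type>> *groupedTypes) {
    std::vector<Type> flat;
    std::size_t sizeBefore = 0;
    for (const Type &t : types) {
      std::vector<Type> converted = convertOne(t);
      flat.insert(flat.end(), converted.begin(), converted.end());
      if (groupedTypes) // record the slice this original type produced
        groupedTypes->emplace_back(flat.begin() + sizeBefore, flat.end());
      sizeBefore = flat.size();
    }
    return flat;
  }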
@@ -740,40 +745,50 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
return allocated;
}
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ValueRange adaptorOperands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
+ SmallVector<ValueRange> ranges;
+ for (size_t i = 0, e = adaptorOperands.size(); i < e; i++)
+ ranges.push_back(adaptorOperands.slice(i, 1));
+ return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
+}
+
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
- promotedOperands.reserve(operands.size());
+ promotedOperands.reserve(adaptorOperands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
- for (auto it : llvm::zip(opOperands, operands)) {
- auto operand = std::get<0>(it);
- auto llvmOperand = std::get<1>(it);
-
+ for (auto [operand, llvmOperand] :
+ llvm::zip_equal(opOperands, adaptorOperands)) {
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (isa<MemRefType>(operand.getType())) {
- MemRefDescriptor desc(llvmOperand);
- llvmOperand = desc.alignedPtr(builder, loc);
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor desc(llvmOperand.front());
+ promotedOperands.push_back(desc.alignedPtr(builder, loc));
+ continue;
} else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
if (isa<UnrankedMemRefType>(operand.getType())) {
- UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
promotedOperands);
continue;
}
if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
promotedOperands);
continue;
}
}
- promotedOperands.push_back(llvmOperand);
+ llvm::append_range(promotedOperands, llvmOperand);
}
return promotedOperands;
}
@@ -802,11 +817,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.append(converted.begin(), converted.end());
return success();
}
- auto converted = converter.convertType(type);
- if (!converted)
- return failure();
- result.push_back(converted);
- return success();
+ return converter.convertType(type, result);
}
/// Callback to convert function argument types. It converts MemRef function
@@ -814,11 +825,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
- if (!llvmTy)
- return failure();
-
- result.push_back(llvmTy);
- return success();
+ return converter.convertCallingConventionType(
+ type, result,
+ /*useBarePointerCallConv=*/true);
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d..2b7bdc9 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -17,11 +17,13 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
+#include <numeric>
using namespace mlir;
@@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+ OpBuilder &builder) {
+ assert(isMemRefTypeLegalForEmitC(memrefType) &&
+ "incompatible memref type for EmitC conversion");
+ emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
+ builder, loc, emitc::SizeTType::get(builder.getContext()),
+ builder.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(builder.getContext(),
+ {TypeAttr::get(memrefType.getElementType())}));
+
+ IndexType indexType = builder.getIndexType();
+ int64_t numElements = std::accumulate(memrefType.getShape().begin(),
+ memrefType.getShape().end(), int64_t{1},
+ std::multiplies<int64_t>());
+ emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
+ builder, loc, indexType, builder.getIndexAttr(numElements));
+
+ Type sizeTType = emitc::SizeTType::get(builder.getContext());
+ emitc::MulOp totalSizeBytes = emitc::MulOp::create(
+ builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+ return totalSizeBytes.getResult();
+}
+
+static emitc::ApplyOp
+createPointerFromEmitcArray(Location loc, OpBuilder &builder,
+ TypedValue<emitc::ArrayType> arrayValue) {
+
+ emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
+ builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
+
+ emitc::ArrayType arrayType = arrayValue.getType();
+ llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
+ emitc::SubscriptOp subPtr =
+ emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
+ emitc::ApplyOp ptr = emitc::ApplyOp::create(
+ builder, loc, emitc::PointerType::get(arrayType.getElementType()),
+ builder.getStringAttr("&"), subPtr);
+
+ return ptr;
+}
+
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -112,19 +156,21 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
Type elementType = memrefType.getElementType();
IndexType indexType = rewriter.getIndexType();
- emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
- loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
+ rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
+ ValueRange{},
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
int64_t numElements = 1;
for (int64_t dimSize : memrefType.getShape()) {
numElements *= dimSize;
}
- Value numElementsValue = rewriter.create<emitc::ConstantOp>(
- loc, indexType, rewriter.getIndexAttr(numElements));
+ Value numElementsValue = emitc::ConstantOp::create(
+ rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
- Value totalSizeBytes = rewriter.create<emitc::MulOp>(
- loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+ Value totalSizeBytes =
+ emitc::MulOp::create(rewriter, loc, sizeTType,
+ sizeofElementOp.getResult(0), numElementsValue);
emitc::CallOpaqueOp allocCall;
StringAttr allocFunctionName;
@@ -132,8 +178,8 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
SmallVector<Value, 2> argsVec;
if (allocOp.getAlignment()) {
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
- alignmentValue = rewriter.create<emitc::ConstantOp>(
- loc, sizeTType,
+ alignmentValue = emitc::ConstantOp::create(
+ rewriter, loc, sizeTType,
rewriter.getIntegerAttr(indexType,
allocOp.getAlignment().value_or(0)));
argsVec.push_back(alignmentValue);
@@ -144,21 +190,62 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
argsVec.push_back(totalSizeBytes);
ValueRange args(argsVec);
- allocCall = rewriter.create<emitc::CallOpaqueOp>(
- loc,
+ allocCall = emitc::CallOpaqueOp::create(
+ rewriter, loc,
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
- emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
- loc, targetPointerType, allocCall.getResult(0));
+ emitc::CastOp castOp = emitc::CastOp::create(
+ rewriter, loc, targetPointerType, allocCall.getResult(0));
rewriter.replaceOp(allocOp, castOp);
return success();
}
};
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = copyOp.getLoc();
+ MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
+ MemRefType targetMemrefType =
+ cast<MemRefType>(copyOp.getTarget().getType());
+
+ if (!isMemRefTypeLegalForEmitC(srcMemrefType))
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible source memref type for EmitC conversion");
+
+ if (!isMemRefTypeLegalForEmitC(targetMemrefType))
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible target memref type for EmitC conversion");
+
+ auto srcArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ emitc::ApplyOp srcPtr =
+ createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
+
+ auto targetArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+ emitc::ApplyOp targetPtr =
+ createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
+
+ emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
+ rewriter, loc, TypeRange{}, "memcpy",
+ ValueRange{
+ targetPtr.getResult(), srcPtr.getResult(),
+ calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+ rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -320,6 +407,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+ ConvertGetGlobal, ConvertLoad, ConvertStore>(
+ converter, patterns.getContext());
}
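For reference, the kind of C code the ConvertCopy pattern's output corresponds to for a hypothetical memref<3x4xf32> copy (written as C++-compatible code; the identifiers are illustrative, not emitted verbatim):

  #include <cstddef>
  #include <cstring>

  void copy_3x4(const float src[3][4], float dst[3][4]) {
    // sizeof(element) * number of elements, as computed by
    // calculateMemrefTotalSizeBytes.
    std::size_t totalSizeBytes = sizeof(float) * 12;
    // Pointers formed from &array[0][0], as in createPointerFromEmitcArray.
    std::memcpy(&dst[0][0], &src[0][0], totalSizeBytes);
  }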
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76..a073a9a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringRef.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +29,15 @@ namespace mlir {
using namespace mlir;
namespace {
+
+emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
+ StringRef headerName) {
+ StringAttr includeAttr = builder.getStringAttr(headerName);
+ return emitc::IncludeOp::create(
+ builder, module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+}
+
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
@@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass
return signalPassFailure();
mlir::ModuleOp module = getOperation();
+ llvm::SmallSet<StringRef, 4> existingHeaders;
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ module.walk([&](mlir::emitc::IncludeOp includeOp) {
+ if (includeOp.getIsStandardInclude())
+ existingHeaders.insert(includeOp.getInclude());
+ });
+
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
- if (callOp.getCallee() != alignedAllocFunctionName &&
- callOp.getCallee() != mallocFunctionName) {
+ StringRef expectedHeader;
+ if (callOp.getCallee() == alignedAllocFunctionName ||
+ callOp.getCallee() == mallocFunctionName)
+ expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader;
+ else if (callOp.getCallee() == memcpyFunctionName)
+ expectedHeader =
+ options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
+ else
return mlir::WalkResult::advance();
+ if (!existingHeaders.contains(expectedHeader)) {
+ addStandardHeader(builder, module, expectedHeader);
+ existingHeaders.insert(expectedHeader);
}
-
- for (auto &op : *module.getBody()) {
- emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
- if (!includeOp) {
- continue;
- }
- if (includeOp.getIsStandardInclude() &&
- ((options.lowerToCpp &&
- includeOp.getInclude() == cppStandardLibraryHeader) ||
- (!options.lowerToCpp &&
- includeOp.getInclude() == cStandardLibraryHeader))) {
- return mlir::WalkResult::interrupt();
- }
- }
-
- mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
- StringAttr includeAttr =
- builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
- : cStandardLibraryHeader);
- builder.create<mlir::emitc::IncludeOp>(
- module.getLoc(), includeAttr,
- /*is_standard_include=*/builder.getUnitAttr());
- return mlir::WalkResult::interrupt();
+ return mlir::WalkResult::advance();
});
}
};
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d6bdd34..262e0e7 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
- SmallVector<Value, 1> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- result, resultAddrSpace, sizes);
- Value resultUnderlyingSize = sizes.front();
+ Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
Value resultUnderlyingDesc =
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ private:
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
- SmallVector<Value, 4> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- targetDesc, addressSpace, sizes);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
Value underlyingDescPtr = LLVM::AllocaOp::create(
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
- sizes.front());
+ allocationSize);
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.
@@ -1872,6 +1869,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::xori:
+ return LLVM::AtomicBinOp::_xor;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2549a9c..37d12ba 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
+ auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
Value ldMatrixResult = NVVM::LdMatrixOp::create(
b, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
- : NVVM::MMALayout::row);
+ : NVVM::MMALayout::row,
+ /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
@@ -394,11 +396,6 @@ struct ConvertNVGPUToNVVMPass
: public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
using Base::Base;
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
- arith::ArithDialect>();
- }
-
void runOnOperation() override {
LowerToLLVMOptions options(&getContext());
RewritePatternSet patterns(&getContext());
@@ -1029,8 +1026,10 @@ struct NVGPUTmaAsyncStoreOpLowering
coords[index] = truncToI32(b, value);
}
+ // TODO: Enhance the NVGPU Op for other modes too
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
- op, adaptor.getTensorMapDescriptor(), dest, coords,
+ op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
+ NVVM::TMAStoreMode::TILE, // default is TILE mode
adaptor.getPredicate());
return success();
}
@@ -1104,12 +1103,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LDBG() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
+ << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
+ << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
@@ -1399,14 +1396,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LDBG() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
- << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
- << "][" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
- << "])";
+ LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+ << "(A[" << (iterationM * wgmmaM) << ":"
+ << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
+ << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+ << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
+ << "][" << 0 << ":" << wgmmaN << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1700,8 +1695,10 @@ struct NVGPUTmaPrefetchOpLowering
LogicalResult
matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
- op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
+ rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
+ op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
+ adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
+ /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
return success();
}
};
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 91788f9..314cbed 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -26,6 +26,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
@@ -57,12 +58,13 @@ struct PtxLowering
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
LDBG() << op.getPtx();
- PtxBuilder generator(op, rewriter);
- op.getAsmValues(rewriter, asmValues);
+ bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
+ PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
- LDBG() << asmValue << "\t Modifier : " << &modifier;
- generator.insertValue(asmValue, modifier);
+ LDBG() << asmValue << "\t Modifier : " << modifier;
+ if (failed(generator.insertValue(asmValue, modifier)))
+ return failure();
}
generator.buildAndReplaceOp();
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 5bd1d49..d57926ec 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <queue>
#define DEBUG_TYPE "pdl-predicate-tree"
@@ -544,7 +545,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
Value value = opIndex.parent;
TypeSwitch<Operation *>(value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
// Get users and iterate over them.
Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
@@ -618,19 +619,15 @@ static Value buildPredicateList(pdl::PatternOp pattern,
RootOrderingGraph graph;
ParentMaps parentMaps;
buildCostGraph(roots, graph, parentMaps);
- LLVM_DEBUG({
- llvm::dbgs() << "Graph:\n";
- for (auto &target : graph) {
- llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
- << "\n";
- for (auto &source : target.second) {
- RootOrderingEntry &entry = source.second;
- llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
- << ":" << entry.cost.second << " via "
- << entry.connector.getLoc() << "\n";
- }
+ LDBG() << "Graph:";
+ for (auto &target : graph) {
+ LDBG() << " * " << target.first.getLoc() << " " << target.first;
+ for (auto &source : target.second) {
+ RootOrderingEntry &entry = source.second;
+ LDBG() << " <- " << source.first << ": " << entry.cost.first << ":"
+ << entry.cost.second << " via " << entry.connector.getLoc();
}
- });
+ }
// Solve the optimal branching problem for each candidate root, or use the
// provided one.
@@ -638,11 +635,11 @@ static Value buildPredicateList(pdl::PatternOp pattern,
OptimalBranching::EdgeList bestEdges;
if (!bestRoot) {
unsigned bestCost = 0;
- LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
+ LDBG() << "Candidate roots:";
for (Value root : roots) {
OptimalBranching solver(graph, root);
unsigned cost = solver.solve();
- LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
+ LDBG() << " * " << root << ": " << cost;
if (!bestRoot || bestCost > cost) {
bestCost = cost;
bestRoot = root;
@@ -656,18 +653,15 @@ static Value buildPredicateList(pdl::PatternOp pattern,
}
// Print the best solution.
- LLVM_DEBUG({
- llvm::dbgs() << "Best tree:\n";
- for (const std::pair<Value, Value> &edge : bestEdges) {
- llvm::dbgs() << " * " << edge.first;
- if (edge.second)
- llvm::dbgs() << " <- " << edge.second;
- llvm::dbgs() << "\n";
- }
- });
+ LDBG() << "Best tree:";
+ for (const std::pair<Value, Value> &edge : bestEdges) {
+ if (edge.second)
+ LDBG() << " * " << edge.first << " <- " << edge.second;
+ else
+ LDBG() << " * " << edge.first;
+ }
- LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
+ LDBG() << "Calling key getTreePredicates (Value: " << bestRoot << ")";
// The best root is the starting point for the traversal. Get the tree
// predicates for the DAG rooted at bestRoot.
@@ -691,7 +685,7 @@ static Value buildPredicateList(pdl::PatternOp pattern,
// Determine the connector.
Value connector = graph[target][source].connector;
assert(connector && "invalid edge");
- LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
+ LDBG() << " * Connector: " << connector.getLoc();
DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
Position *pos = valueToPosition.lookup(connector);
assert(pos && "connector has not been traversed yet");
@@ -806,9 +800,9 @@ static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
/// Get or insert a child matcher for the given parent switch node, given a
/// predicate and parent pattern.
-std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
- OrderedPredicate *predicate,
- pdl::PatternOp pattern) {
+static std::unique_ptr<MatcherNode> &
+getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate,
+ pdl::PatternOp pattern) {
assert(isSamePredicate(node, predicate) &&
"expected matcher to equal the given predicate");
diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt
new file mode 100644
index 0000000..2d416be
--- /dev/null
+++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRPtrToLLVM
+ PtrToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRPtrDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ )
diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
new file mode 100644
index 0000000..a0758aa
--- /dev/null
+++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
@@ -0,0 +1,440 @@
+//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <type_traits>
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FromPtrOpConversion
+//===----------------------------------------------------------------------===//
+struct FromPtrOpConversion : public ConvertOpToLLVMPattern<ptr::FromPtrOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOpConversion
+//===----------------------------------------------------------------------===//
+struct GetMetadataOpConversion
+ : public ConvertOpToLLVMPattern<ptr::GetMetadataOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// PtrAddOpConversion
+//===----------------------------------------------------------------------===//
+struct PtrAddOpConversion : public ConvertOpToLLVMPattern<ptr::PtrAddOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// ToPtrOpConversion
+//===----------------------------------------------------------------------===//
+struct ToPtrOpConversion : public ConvertOpToLLVMPattern<ptr::ToPtrOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeOffsetOpConversion
+//===----------------------------------------------------------------------===//
+struct TypeOffsetOpConversion
+ : public ConvertOpToLLVMPattern<ptr::TypeOffsetOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Internal functions
+//===----------------------------------------------------------------------===//
+
+// Creates an LLVM struct type representing the metadata of a ranked memref.
+static FailureOr<LLVM::LLVMStructType>
+createMemRefMetadataType(MemRefType type,
+ const LLVMTypeConverter &typeConverter) {
+ MLIRContext *context = type.getContext();
+ // Get the address space.
+ FailureOr<unsigned> addressSpace = typeConverter.getMemRefAddressSpace(type);
+ if (failed(addressSpace))
+ return failure();
+
+ // Get the pointer type in the memref's converted address space.
+ auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
+
+ // Get the strides, offset, and shape.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(type.getStridesAndOffset(strides, offset)))
+ return failure();
+ ArrayRef<int64_t> shape = type.getShape();
+
+ // Use index type from the type converter for the descriptor elements
+ Type indexType = typeConverter.getIndexType();
+
+ // For a ranked memref, the metadata struct contains:
+ // 1. The pointer to the allocated data.
+ // 2. The dynamic offset, if the offset is not statically known.
+ // 3. The dynamic sizes, one per dynamic dimension.
+ // 4. The dynamic strides, one per dynamic stride.
+ // The aligned pointer is carried by the pointer value itself rather than
+ // by the metadata.
+ SmallVector<Type, 5> elements;
+
+ // Allocated pointer.
+ elements.push_back(ptrType);
+
+ // Potentially add the dynamic offset.
+ if (offset == ShapedType::kDynamic)
+ elements.push_back(indexType);
+
+ // Potentially add the dynamic sizes.
+ for (int64_t dim : shape) {
+ if (dim == ShapedType::kDynamic)
+ elements.push_back(indexType);
+ }
+
+ // Potentially add the dynamic strides.
+ for (int64_t stride : strides) {
+ if (stride == ShapedType::kDynamic)
+ elements.push_back(indexType);
+ }
+ return LLVM::LLVMStructType::getLiteral(context, elements);
+}
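As a concrete (hypothetical) example of the layout this helper produces: for memref<?x4xf32, strided<[4, 1], offset: ?>> the metadata struct holds the allocated pointer plus one index each for the dynamic offset and the dynamic size of dim 0, while the static size and strides are omitted. A C++ analogue:

  #include <cstdint>

  struct MemRefMetadataExample {
    void *allocatedPtr;  // field 0: allocated pointer
    std::int64_t offset; // field 1: dynamic offset
    std::int64_t size0;  // field 2: dynamic extent of dim 0
  };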
+
+//===----------------------------------------------------------------------===//
+// FromPtrOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult FromPtrOpConversion::matchAndRewrite(
+ ptr::FromPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Get the target memref type
+ auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
+ if (!mTy)
+ return rewriter.notifyMatchFailure(op, "Expected memref result type");
+
+ if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
+ return rewriter.notifyMatchFailure(
+ op, "Can convert only memrefs with metadata");
+ }
+
+ // Convert the result type
+ Type descriptorTy = getTypeConverter()->convertType(mTy);
+ if (!descriptorTy)
+ return rewriter.notifyMatchFailure(op, "Failed to convert result type");
+
+ // Get the strides, offset, and shape.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(mTy.getStridesAndOffset(strides, offset))) {
+ return rewriter.notifyMatchFailure(op,
+ "Failed to get the strides and offset");
+ }
+ ArrayRef<int64_t> shape = mTy.getShape();
+
+ // Create a new memref descriptor
+ Location loc = op.getLoc();
+ auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
+
+ // Set the allocated and aligned pointers.
+ desc.setAllocatedPtr(
+ rewriter, loc,
+ rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getMetadata(), 0));
+ desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
+
+ // Extract metadata from the passed struct.
+ unsigned fieldIdx = 1;
+
+ // Set dynamic offset if needed.
+ if (offset == ShapedType::kDynamic) {
+ Value offsetValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setOffset(rewriter, loc, offsetValue);
+ } else {
+ desc.setConstantOffset(rewriter, loc, offset);
+ }
+
+ // Set dynamic sizes if needed.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ if (dim == ShapedType::kDynamic) {
+ Value sizeValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setSize(rewriter, loc, i, sizeValue);
+ } else {
+ desc.setConstantSize(rewriter, loc, i, dim);
+ }
+ }
+
+ // Set dynamic strides if needed.
+ for (auto [i, stride] : llvm::enumerate(strides)) {
+ if (stride == ShapedType::kDynamic) {
+ Value strideValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setStride(rewriter, loc, i, strideValue);
+ } else {
+ desc.setConstantStride(rewriter, loc, i, stride);
+ }
+ }
+
+ rewriter.replaceOp(op, static_cast<Value>(desc));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult GetMetadataOpConversion::matchAndRewrite(
+ ptr::GetMetadataOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
+ if (!mTy)
+ return rewriter.notifyMatchFailure(op, "Only memref metadata is supported");
+
+ // Get the metadata type.
+ FailureOr<LLVM::LLVMStructType> mdTy =
+ createMemRefMetadataType(mTy, *getTypeConverter());
+ if (failed(mdTy)) {
+ return rewriter.notifyMatchFailure(op,
+ "Failed to create the metadata type");
+ }
+
+ // Get the memref descriptor.
+ MemRefDescriptor descriptor(adaptor.getPtr());
+
+ // Get the strides, offset, and shape.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(mTy.getStridesAndOffset(strides, offset))) {
+ return rewriter.notifyMatchFailure(op,
+ "Failed to get the strides and offset");
+ }
+ ArrayRef<int64_t> shape = mTy.getShape();
+
+ // Create a new LLVM struct to hold the metadata
+ Location loc = op.getLoc();
+ Value sV = rewriter.create<LLVM::UndefOp>(loc, *mdTy);
+
+ // First element is the allocated pointer.
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.allocatedPtr(rewriter, loc), 0);
+
+ // Track the current field index.
+ unsigned fieldIdx = 1;
+
+ // Add dynamic offset if needed.
+ if (offset == ShapedType::kDynamic) {
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
+ }
+
+ // Add dynamic sizes if needed.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ if (dim != ShapedType::kDynamic)
+ continue;
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
+ }
+
+ // Add dynamic strides if needed
+ for (auto [i, stride] : llvm::enumerate(strides)) {
+ if (stride != ShapedType::kDynamic)
+ continue;
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
+ }
+ rewriter.replaceOp(op, sV);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PtrAddOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Get and check the base.
+ Value base = adaptor.getBase();
+ if (!isa<LLVM::LLVMPointerType>(base.getType()))
+ return rewriter.notifyMatchFailure(op, "Incompatible pointer type");
+
+ // Get the offset.
+ Value offset = adaptor.getOffset();
+
+ // Ptr assumes the offset is in bytes.
+ Type elementType = IntegerType::get(rewriter.getContext(), 8);
+
+ // Convert the `ptradd` flags.
+ LLVM::GEPNoWrapFlags flags;
+ switch (op.getFlags()) {
+ case ptr::PtrAddFlags::none:
+ flags = LLVM::GEPNoWrapFlags::none;
+ break;
+ case ptr::PtrAddFlags::nusw:
+ flags = LLVM::GEPNoWrapFlags::nusw;
+ break;
+ case ptr::PtrAddFlags::nuw:
+ flags = LLVM::GEPNoWrapFlags::nuw;
+ break;
+ case ptr::PtrAddFlags::inbounds:
+ flags = LLVM::GEPNoWrapFlags::inbounds;
+ break;
+ }
+
+ // Create the GEP operation with appropriate arguments
+ rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType,
+ base, ValueRange{offset}, flags);
+ return success();
+}
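Since ptr.ptr_add offsets are in bytes, the GEP above is byte-wise pointer arithmetic; a hedged C++ analogue (ptrAddBytes is an invented name):

  #include <cstdint>

  void *ptrAddBytes(void *base, std::int64_t offsetBytes) {
    // Equivalent of an i8-typed getelementptr on the base pointer.
    return static_cast<char *>(base) + offsetBytes;
  }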
+
+//===----------------------------------------------------------------------===//
+// ToPtrOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Bail if it's not a memref.
+ if (!isa<MemRefType>(op.getPtr().getType()))
+ return rewriter.notifyMatchFailure(op, "Expected a memref input");
+
+ // Extract the aligned pointer from the memref descriptor.
+ rewriter.replaceOp(
+ op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeOffsetOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult TypeOffsetOpConversion::matchAndRewrite(
+ ptr::TypeOffsetOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Convert the type attribute.
+ Type type = getTypeConverter()->convertType(op.getElementType());
+ if (!type)
+ return rewriter.notifyMatchFailure(op, "Couldn't convert the type");
+
+ // Convert the result type.
+ Type rTy = getTypeConverter()->convertType(op.getResult().getType());
+ if (!rTy)
+ return rewriter.notifyMatchFailure(op, "Couldn't convert the result type");
+
+ // TODO: Use MLIR's data layout. We don't use it because overall support is
+ // still flaky.
+
+ // Create an LLVM pointer type for the GEP operation.
+ auto ptrTy = LLVM::LLVMPointerType::get(getContext());
+
+ // Create a GEP operation to compute the offset of the type.
+ auto offset =
+ LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
+ LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
+ ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
+
+ // Replace the original op with a PtrToIntOp using the computed offset.
+ rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
+ return success();
+}
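The lowering above uses the gep-from-null-then-ptrtoint idiom, which evaluates to the byte size of the converted element type; in C++ terms this is just sizeof, sketched here under that assumption:

  #include <cstdint>

  template <typename T>
  constexpr std::uint64_t typeOffsetInBytes() {
    return sizeof(T); // what ptrtoint(gep %null[1]) computes for type T
  }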
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert Ptr to LLVM.
+struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<LLVM::LLVMDialect>();
+ }
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ void populateConvertToLLVMConversionPatterns(
+ ConversionTarget &target, LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) const final {
+ ptr::populatePtrToLLVMConversionPatterns(converter, patterns);
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// API
+//===----------------------------------------------------------------------===//
+
+void mlir::ptr::populatePtrToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ // Add address space conversions.
+ converter.addTypeAttributeConversion(
+ [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
+ -> TypeConverter::AttributeConversionResult {
+ if (type.getMemorySpace() != memorySpace)
+ return TypeConverter::AttributeConversionResult::na();
+ return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
+ });
+
+ // Add type conversions.
+ converter.addConversion([&](ptr::PtrType type) -> Type {
+ std::optional<Attribute> maybeAttr =
+ converter.convertTypeAttribute(type, type.getMemorySpace());
+ auto memSpace =
+ maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
+ if (!memSpace)
+ return {};
+ return LLVM::LLVMPointerType::get(type.getContext(),
+ memSpace.getValue().getSExtValue());
+ });
+
+ // Convert ptr metadata of memref type.
+ converter.addConversion([&](ptr::PtrMetadataType type) -> Type {
+ auto mTy = dyn_cast<MemRefType>(type.getType());
+ if (!mTy)
+ return {};
+ FailureOr<LLVM::LLVMStructType> res =
+ createMemRefMetadataType(mTy, converter);
+ return failed(res) ? Type() : res.value();
+ });
+
+ // Add conversion patterns.
+ patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
+ ToPtrOpConversion, TypeOffsetOpConversion>(converter);
+}
+
+void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
+ dialect->addInterfaces<PtrToLLVMDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index ba448e4..37cfc9f 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -382,8 +382,11 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
- auto comparison = arith::CmpIOp::create(
- rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound);
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto comparison =
+ arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 84cbd86..1f239aa 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -154,6 +154,10 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables;
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index badd2f6..7d0a236 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -27,7 +27,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <optional>
#define DEBUG_TYPE "loops-to-gpu"
@@ -134,7 +134,7 @@ static LogicalResult checkAffineLoopNestMappable(AffineForOp forOp,
unsigned numBlockDims,
unsigned numThreadDims) {
if (numBlockDims < 1 || numThreadDims < 1) {
- LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+ LDBG() << "nothing to map";
return success();
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372a..c4a9fc2 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,18 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
- ConversionTarget target(*module.getContext());
- target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
- target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
- memref::MemRefDialect>();
-
RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
- return applyPartialConversion(module, target, frozen);
+ walkAndApplyPatterns(module, frozen);
+ auto status = module.walk([](Operation *op) {
+ if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
+ op->emitError("unconverted operation found");
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return failure(status.wasInterrupted());
}
/// A pass converting SCF operations to OpenMP operations.
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index dc92367f..55ed31e 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -178,8 +178,14 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
- newIndVar, adaptor.getUpperBound());
+ Value cmpOp;
+ if (forOp.getUnsignedCmp()) {
+ cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
+ } else {
+ cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
+ }
spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
ArrayRef<Value>(), mergeBlock,
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index fa9e544..398ab88 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -301,7 +301,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
// Create mpi::CommRankOp
Location loc = op.getLoc();
- auto ctx = op.getContext();
+ auto *ctx = op.getContext();
Value commWorld =
mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
auto rank = mpi::CommRankOp::create(
@@ -520,7 +520,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
};
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
- auto ctx = kind.getContext();
+ auto *ctx = kind.getContext();
auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
};
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725..e568660 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ public:
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingMode roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingMode::DOUBLE_ROUND &&
+ roundingMode != RoundingMode::SINGLE_ROUND) {
return failure();
}
@@ -100,7 +101,7 @@ public:
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ public:
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingMode roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingMode::DOUBLE_ROUND &&
+ roundingMode != RoundingMode::SINGLE_ROUND) {
return failure();
}
@@ -179,7 +181,7 @@ public:
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de06..d0a431b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
auto nanMode = op.getNanMode();
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagationMode::PROPAGATE)
return result;
// Unordered comparison of NaN against itself will always return true.
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::MulOp>(op)) {
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
- if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
- }
-
- int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ bool shiftIsConstant = true;
+ int32_t shift = 0;
+ if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+ shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ else
+ shiftIsConstant = false;
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
@@ -147,23 +147,26 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value a = args[0];
Value b = args[1];
- if (shift > 0) {
- auto shiftConst =
- arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+ if (shift > 0 || !shiftIsConstant) {
+ Value shiftConst;
+ if (shiftIsConstant)
+ shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
if (!a.getType().isInteger(32))
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getStringAttr("SINGLE_ROUND"));
-
- if (elementTy.isInteger(32))
- return result;
+ auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
+ auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
+ RoundingMode::SINGLE_ROUND);
+ auto result =
+ tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+ b, shiftAmount, roundingAttr);
- return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+ return result;
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -464,7 +467,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// In the case of "PROPAGATE" semantics no compare and selection is
// required.
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagationMode::PROPAGATE)
return result;
// In the case of "IGNORE" semantics materialize a comparison
@@ -918,6 +921,18 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
if (operands.size() == 1)
return operands;
+  // No need to broadcast for static shapes
+ bool hasDynamic = false;
+ for (auto op : operands) {
+ const auto tType = dyn_cast<RankedTensorType>(op.getType());
+ if (tType && !tType.hasStaticShape()) {
+ hasDynamic = true;
+ break;
+ }
+ }
+ if (!hasDynamic)
+ return operands;
+
// Broadcast dynamic dimensions operand by operand
return llvm::map_to_vector(operands, [&](Value operand) {
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1005,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static ValueRange getBroadcastableOperands(Operation *operation,
ValueRange operands) {
// Shift cannot broadcast
- if (isa<tosa::MulOp>(operation))
- return operands.take_front(2);
+  if (isa<tosa::MulOp>(operation)) {
+    DenseElementsAttr shiftElems;
+    // A constant shift does not participate in broadcasting.
+    if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+      return operands.take_front(2);
+    return operands.take_front(3);
+  }
// Input1_zp and output_zp cannot broadcast
if (isa<tosa::NegateOp>(operation))
return operands.take_front(1);
@@ -1173,7 +1194,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
// NaN propagation has no meaning for non floating point types.
- if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+ if (isa<FloatType>(elementTy) &&
+ op.getNanMode() == NanPropagationMode::IGNORE) {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1358,11 @@ public:
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getRoundingMode() == "INEXACT_ROUND")
+ if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
- if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+ if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1386,11 +1408,10 @@ public:
// is ever true.
bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
+ op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ RoundingMode roundingMode =
+ doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1594,7 @@ public:
auto input = op.getInput();
auto inputTy = cast<RankedTensorType>(input.getType());
auto resultTy = cast<RankedTensorType>(op.getType());
- const bool isBilinear = op.getMode() == "BILINEAR";
+ const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
auto inputH = inputTy.getDimSize(1);
auto inputW = inputTy.getDimSize(2);
@@ -1584,8 +1605,8 @@ public:
return rewriter.notifyMatchFailure(
op, "tosa.resize is not a pure 1x1->1x1 image operation");
- // TODO(suderman): These string values should be declared the TOSA dialect.
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeMode::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1785,7 +1806,8 @@ public:
return rewriter.notifyMatchFailure(
op, "unable to get dynamic dimensions of tosa.resize");
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeMode::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1890,7 +1912,7 @@ public:
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
}
- if (op.getMode() == "NEAREST_NEIGHBOR") {
+ if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1948,7 @@ public:
linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
- assert(op.getMode() == "BILINEAR");
+ assert(op.getMode() == ResizeMode::BILINEAR);
auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
@@ -2291,7 +2313,7 @@ public:
Value predicate;
if (isa<FloatType>(inElementTy)) {
- if (argmaxOp.getNanMode() == "IGNORE") {
+ if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca..6f28849 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ public:
dilationAttr);
rewriter.setInsertionPointAfter(op);
- StringRef nanMode = op.getNanMode();
+ NanPropagationMode nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
// NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ public:
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
- if (nanMode == "IGNORE") {
+ if (nanMode == NanPropagationMode::IGNORE) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ public:
rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
- auto scaled =
- tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
- shift, rewriter.getStringAttr("SINGLE_ROUND"))
- .getResult();
+ auto roundingAttr = RoundingModeAttr::get(
+ rewriter.getContext(), RoundingMode::SINGLE_ROUND);
+
+ auto scaled = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal,
+ multiplier, shift, roundingAttr)
+ .getResult();
// If we have quantization information we need to apply output
// zeropoint.
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000..2d4b2b6
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToAMX
+ VectorToAMX.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAMXDialect
+ MLIRAffineUtils
+ MLIRArithDialect
+ MLIRLinalgUtils
+ MLIRMemRefDialect
+ MLIRSCFDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000..7b9ed1d
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,429 @@
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/DebugLog.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "vector-to-amx"
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
+static bool verifyAmxShape(VectorType vec) {
+ // Check overall shape:
+ // - 2D for plain layout input or output
+ // - 3D for VNNI packed input
+ if (vec.getRank() != 2 && vec.getRank() != 3)
+ return false;
+
+ ArrayRef<int64_t> shape = vec.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = shape[1];
+ unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+ // 3D shape indicates VNNI packed layout.
+ if (vec.getRank() == 3) {
+ int64_t vnniFactor = 32 / elemBitWidth;
+ if (shape.back() != vnniFactor) {
+ LDBG() << "invalid VNNI packing factor";
+ return false;
+ }
+ cols *= vnniFactor;
+ }
+
+ // AMX tile supports up to 16 rows of 64 bytes each.
+ constexpr unsigned maxRows = 16;
+ constexpr unsigned maxBitsPerRow = 64 * 8;
+ return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
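The check above boils down to simple arithmetic: at most 16 rows, at most 64 bytes per row, and for packed 3-D shapes a trailing VNNI dimension of 32 / element-bit-width that gets folded into the column count. A minimal standalone sketch of that arithmetic, with a hypothetical helper name and sample shapes, is:

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring verifyAmxShape: fold the VNNI dimension into
// the column count, then require <= 16 rows and <= 64 bytes per row.
static bool fitsAmxTile(int64_t rows, int64_t cols, unsigned elemBitWidth,
                        bool vnniPacked) {
  if (vnniPacked)
    cols *= 32 / elemBitWidth; // VNNI packs 32 bits worth of elements.
  return rows <= 16 && cols * elemBitWidth <= 64 * 8;
}

int main() {
  assert(fitsAmxTile(16, 16, 16, /*vnniPacked=*/true));   // e.g. 16x16x2 bf16
  assert(!fitsAmxTile(16, 64, 16, /*vnniPacked=*/false)); // 128-byte rows
  return 0;
}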
+
+/// Check if contraction operands are in AMX-compatible packed VNNI layout.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Expect 3D inputs for VNNI packed data.
+ VectorType lhsType = contractOp.getLhs().getType();
+ VectorType rhsType = contractOp.getRhs().getType();
+ if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs and rhs 3D vectors");
+
+ // Check if shapes are compatible with AMX tile.
+ if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+ !verifyAmxShape(accType))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+ // Validate affine maps.
+ //
+ // Iterators can be ordered arbitrarily. Indexing map positions are based on
+ // operands' target shapes.
+ // The matrix layouts must match the following:
+ // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+ // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+ // - matrix C - [M]x[N]
+ SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+ if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+ mapB.getNumResults() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid input indexing maps");
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Failed to infer contraction dims");
+ // Two reduction dimensions are expected:
+ // - one for the K dimension
+ // - one for the VNNI factor
+ if (dims->k.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expected two reduction dims");
+ assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+ "Invalid parallel contraction dims");
+
+ SmallVector<vector::IteratorType> iteratorTypes =
+ contractOp.getIteratorTypesArray();
+ // Check VNNI dim maps - the innermost dim for A and B inputs.
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+ // Check K dim maps - non-transposed row-major layout.
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+ // Check M and N dim maps - map to non-transposed output.
+ AffineMap mapC = indexingMaps[2];
+ auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+ auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+ if (!mDimC || !nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+ auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+ if (!parallelDimA ||
+ iteratorTypes[parallelDimA.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimA != mDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+ if (!parallelDimB ||
+ iteratorTypes[parallelDimB.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimB != nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+ return success();
+}
+
+/// Validate contraction operands for AMX lowering.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType)
+ return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+ // Check if operand types are compatible with AMX compute ops.
+ bool validElemTypes = false;
+ Type lhsElemType = contractOp.getLhs().getType().getElementType();
+ Type rhsElemType = contractOp.getRhs().getType().getElementType();
+ Type accElemType = accType.getElementType();
+ if (accElemType.isInteger(32)) {
+ validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
+ } else if (accElemType.isF32()) {
+ validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
+ (lhsElemType.isBF16() && rhsElemType.isBF16());
+ }
+ if (!validElemTypes)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid combination of operand types");
+
+ if (failed(isAmxVnniLayout(rewriter, contractOp)))
+ return failure();
+
+ return success();
+}
+
+/// Collapse the two innermost dimensions together.
+static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
+ TypedValue<MemRefType> memref) {
+ int64_t rank = memref.getType().getRank();
+ SmallVector<ReassociationIndices> reassocIndices;
+ for (auto i : llvm::seq<int64_t>(0, rank - 2))
+ reassocIndices.push_back({i});
+ reassocIndices.push_back({rank - 2, rank - 1});
+ return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+ reassocIndices);
+}
+
+/// Attempt to create an AMX tile load/store operation equivalent to the given
+/// vector transfer `xfer` op.
+/// This approach allows skipping the longer route through registers and a
+/// temporary buffer that would otherwise be required to move data to/from an
+/// AMX tile.
+static Operation *
+loadStoreFromTransfer(PatternRewriter &rewriter,
+ VectorTransferOpInterface xferOp, bool isPacked,
+ TypedValue<amx::TileType> tileToStore = nullptr) {
+ if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
+ return nullptr;
+ if (xferOp.hasOutOfBoundsDim() ||
+ !xferOp.getPermutationMap().isMinorIdentity())
+ return nullptr;
+
+ // Extra checks in case of a write op.
+ // Stores must not be packed.
+ if (isa<vector::TransferWriteOp>(xferOp) &&
+ (!tileToStore || isPacked ||
+ tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
+ return nullptr;
+
+ // Check for a memref source buffer.
+ // AMX data transfer requires at least 2D shape to correctly
+ // infer stride between rows.
+ Value base = xferOp.getBase();
+ auto memTy = dyn_cast<MemRefType>(base.getType());
+  if (!memTy || memTy.getRank() < 2)
+    return nullptr;
+  int64_t memRank = memTy.getRank();
+
+  // Check that the source buffer has enough contiguous elements to load a
+  // whole AMX tile row.
+ //
+ // To ensure correctness, the validation is conservative and expects the
+ // buffer's innermost dimensions to be statically known, equal to or larger
+ // than the vector row length, and equal to the VNNI dimension if applicable.
+ //
+ // This check could be relaxed to accept more arbitrarily shaped buffers as
+ // long as there are enough contiguous elements to load a whole row.
+ if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
+ return nullptr;
+ VectorType vecTy = xferOp.getVectorType();
+ ArrayRef<int64_t> vecShape = vecTy.getShape();
+ ArrayRef<int64_t> memShape = memTy.getShape();
+ if (memShape.back() == ShapedType::kDynamic ||
+ memShape.back() < vecShape.back())
+ return nullptr;
+ if (isPacked &&
+ (memShape.back() != vecShape.back() ||
+ memShape[memShape.size() - 2] == ShapedType::kDynamic ||
+ memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
+ return nullptr;
+
+ // Load values directly from the buffer to an AMX tile.
+ PatternRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(xferOp);
+ Location loc = xferOp.getLoc();
+
+ // Create a subview of the source buffer based on the transfer op to resolve
+ // offsets.
+ SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
+ int64_t vecRank = vecTy.getRank();
+ assert(memRank >= vecRank &&
+ "Expects buffer to be the same or greater rank than vector");
+ SmallVector<int64_t> shape(memRank - vecRank, 1);
+ shape.append(vecShape.begin(), vecShape.end());
+ TypedValue<MemRefType> src =
+ memref::SubViewOp::create(
+ rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
+ .getResult();
+
+ // Collapse the VNNI dimension in case of packing.
+ if (isPacked)
+ src = collapseLastDim(rewriter, src);
+ int64_t rows = vecShape[0];
+ int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIndices(src.getType().getRank(), zeroIndex);
+
+  Operation *amxTileOp = nullptr;
+  if (isa<vector::TransferReadOp>(xferOp)) {
+    amxTileOp =
+        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndices);
+  } else if (isa<vector::TransferWriteOp>(xferOp)) {
+    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndices,
+                                         tileToStore);
+ } else {
+ llvm_unreachable("unsupported vector transfer op");
+ }
+
+ return amxTileOp;
+}
+
+/// Attempt to create an AMX tile load operation equivalent to the given
+/// vector transfer `readOp`.
+/// Returns loaded AMX tile if successful.
+static FailureOr<TypedValue<amx::TileType>>
+loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
+ bool isPacked) {
+ amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
+ loadStoreFromTransfer(rewriter, readOp, isPacked));
+ if (!loadOp)
+ return failure();
+ return loadOp.getRes();
+}
+
+/// Attempt to create an AMX tile store operation equivalent to the given
+/// vector transfer `writeOp`.
+static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
+ vector::TransferWriteOp writeOp,
+ TypedValue<amx::TileType> tileToStore) {
+ return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
+ tileToStore));
+}
+
+/// Load vector values to an AMX tile.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+ TypedValue<VectorType> vec) {
+ Location loc = vec.getLoc();
+
+ VectorType vecTy = vec.getType();
+ bool isPacked = vecTy.getRank() == 3;
+
+ // Try to load tile directly from vector producer's buffer.
+ auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
+ FailureOr<TypedValue<amx::TileType>> tile =
+ loadFromTransfer(rewriter, readOp, isPacked);
+ if (succeeded(tile))
+ return *tile;
+
+ // Transfer the vector to a tile through an intermediate buffer.
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+ vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+ // Collapse the VNNI dimension in case of packing.
+ if (isPacked)
+ buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+ ArrayRef<int64_t> shape = vecTy.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+ {zeroIndex, zeroIndex});
+}
+
+/// Store an AMX tile in a vector.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+ TypedValue<amx::TileType> tile) {
+ Location loc = tile.getLoc();
+
+ // Transfer the tile to a vector through an intermediate buffer.
+ amx::TileType tileTy = tile.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc,
+ MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(2, zeroIndex);
+ amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+
+ auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
+ return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
+}
+
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ if (failed(validateOperands(rewriter, contractOp)))
+ return failure();
+
+ TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+ TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+ auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
+ assert(acc && "Invalid accumulator type");
+ TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+
+ TypedValue<amx::TileType> tileMul;
+ if (acc.getType().getElementType().isFloat()) {
+ tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ } else {
+ tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ }
+
+ // If the contraction result is only written back to memory, try to replace
+ // the vector op with an AMX store directly.
+ Value res = contractOp.getResult();
+ if (res.hasOneUse()) {
+ auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
+ LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
+ if (succeeded(storeRes)) {
+ rewriter.eraseOp(writeOp);
+ rewriter.eraseOp(contractOp);
+ return success();
+ }
+ }
+
+ // Load the result back into a vector.
+ Value newResult = storeTile(rewriter, tileMul);
+ rewriter.replaceOp(contractOp, newResult);
+
+ return success();
+ }
+};
+
+struct ConvertVectorToAMXPass
+ : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+ void runOnOperation() override {
+ MLIRContext &ctx = getContext();
+ RewritePatternSet patterns(&ctx);
+ populateVectorToAMXConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ContractionToAMX>(patterns.getContext());
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f9e2a01..1ff7d5d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -306,11 +306,11 @@ public:
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
+ base, ptr, adaptor.getIndices(), vType);
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -362,10 +362,10 @@ public:
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+ adaptor.getBase(), ptr, adaptor.getIndices(), vType);
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
@@ -1891,15 +1891,21 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
+    // Only 1-D vectors are supported here. Multi-dimensional vectors should
+    // have been lowered to 1-D vectors by earlier vector-to-vector
+    // transformations.
if (vectorType.getRank() > 1)
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
+ auto constIdx =
+ LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
+ result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
+ val, constIdx);
+ }
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index cf10869..9852df6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorFromElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4..036cbad 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,22 @@ struct VectorLoadOpConverter final
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+ std::optional<uint64_t> alignment = loadOp.getAlignment();
+ if (alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(loadOp,
+ "invalid alignment requirement");
+ }
+
+ auto memoryAccess = spirv::MemoryAccess::None;
+ spirv::MemoryAccessAttr memoryAccessAttr;
+ IntegerAttr alignmentAttr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
@@ -753,7 +769,8 @@ struct VectorLoadOpConverter final
accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
- castedAccessChain);
+ castedAccessChain,
+ memoryAccessAttr, alignmentAttr);
return success();
}
@@ -782,6 +799,12 @@ struct VectorStoreOpConverter final
return rewriter.notifyMatchFailure(
storeOp, "failed to get memref element pointer");
+ std::optional<uint64_t> alignment = storeOp.getAlignment();
+ if (alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(storeOp,
+ "invalid alignment requirement");
+ }
+
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +818,19 @@ struct VectorStoreOpConverter final
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
accessChain);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
- adaptor.getValueToStore());
+ auto memoryAccess = spirv::MemoryAccess::None;
+ spirv::MemoryAccessAttr memoryAccessAttr;
+ IntegerAttr alignmentAttr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+ alignmentAttr);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
index 567083d..e9ad67c5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU
MLIRTransforms
MLIRVectorDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8010755..819c2e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -14,9 +14,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -68,11 +70,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
if (!srcTy)
return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
- // Perform common data transfer checks.
- VectorType vecTy = xferOp.getVectorType();
- if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
- return failure();
-
// Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
@@ -80,6 +77,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");
+ VectorType vecTy = xferOp.getVectorType();
unsigned vecRank = vecTy.getRank();
if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
return rewriter.notifyMatchFailure(
@@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
return ndDesc;
}
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the innermost strides in the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is computed
+// using the inverse and broadcasting-aware version of the permutation map,
+// and is applied to the relevant strides. This ensures that memory accesses
+// are consistent with the logical permutation of vector elements.
+//
+// Example:
+// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
+// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
+// 0]), then after calling this function, the last two strides will be
+// swapped:
+// Original strides: [s0, s1, s2, s3]
+// After permutation: [s0, s1, s3, s2]
+//
+static void adjustStridesForPermutation(AffineMap permMap,
+ SmallVectorImpl<Value> &strides) {
+
+ AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+ SmallVector<unsigned> perms;
+ invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+ SmallVector<int64_t> perms64(perms.begin(), perms.end());
+ strides = applyPermutation(strides, perms64);
+}
+
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ SmallVector<Value> strides;
+ Value baseMemref = xferOp.getBase();
+ AffineMap permMap = xferOp.getPermutationMap();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+ Location loc = xferOp.getLoc();
+ if (memrefType.hasStaticShape()) {
+ int64_t offset;
+ SmallVector<int64_t> intStrides;
+ if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+ return {};
+ // Wrap static strides as MLIR values
+ for (int64_t s : intStrides)
+ strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+ } else {
+ // For dynamic shape memref, use memref.extract_strided_metadata to get
+ // stride values
+ unsigned rank = memrefType.getRank();
+ Type indexType = rewriter.getIndexType();
+
+    // Result types: [base_memref, offset, size0, size1, ..., sizeN-1,
+    // stride0, stride1, ..., strideN-1]
+    SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // base memref (rank zero)
+    resultTypes.push_back(indexType); // offset
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // sizes
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // strides
+
+ auto meta = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, resultTypes, baseMemref);
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+ }
+ // Adjust strides according to the permutation map (e.g., for transpose)
+ adjustStridesForPermutation(permMap, strides);
+ return strides;
+}
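For the static-shape branch above, and assuming the common identity (row-major) layout, the extracted strides are just the running products of the trailing dimension sizes. A small illustrative sketch (the helper name is hypothetical; the sample shape is the one used in the transfer_read example below):

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides: the innermost stride is 1, each outer stride is the
// product of all inner dimension sizes.
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

int main() {
  std::vector<int64_t> strides = rowMajorStrides({8, 4, 2, 6, 32});
  assert((strides == std::vector<int64_t>{1536, 384, 192, 32, 1}));
  return 0;
}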
+
+// This function computes the vector of local offsets for scattered
+// load/stores. It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter. Example:
+//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+//              %cst {in_bounds = [true, true, true, true]} :
+//              memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+// %6 = vector.step: vector<4xindex>
+// %7 = vector.step: vector<2xindex>
+// %8 = vector.step: vector<6xindex>
+// %9 = vector.step: vector<32xindex>
+// %10 = arith.mul %6, 384
+// %11 = arith.mul %7, 192
+// %12 = arith.mul %8, 32
+// %13 = arith.mul %9, 1
+//   %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
+//   %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
+//   %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
+//   %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
+//   %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+//   %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+//   %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+//   %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+// %22 = arith.add %18, %19
+// %23 = arith.add %20, %21
+// %local_offsets = arith.add %22, %23
+// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+// %offsets = orig_offset + local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Value> strides) {
+ Location loc = xferOp.getLoc();
+ VectorType vectorType = xferOp.getVectorType();
+ SmallVector<Value> indices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ // Create vector.step operations for each dimension
+  SmallVector<Value> stepVectors;
+  for (int64_t dim : vectorShape) {
+    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+    stepVectors.push_back(vector::StepOp::create(rewriter, loc, stepType));
+  }
+
+ // Multiply step vectors by corresponding strides
+ size_t memrefRank = strides.size();
+ size_t vectorRank = vectorShape.size();
+ SmallVector<Value> strideMultiplied;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ size_t memrefDim = memrefRank - vectorRank + i;
+ Value strideValue = strides[memrefDim];
+ auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
+ auto bcastOp =
+ vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+ auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+ strideMultiplied.push_back(mulOp);
+ }
+
+ // Shape cast each multiplied vector to add singleton dimensions
+ SmallVector<Value> shapeCasted;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ SmallVector<int64_t> newShape(vectorRank, 1);
+ newShape[i] = vectorShape[i];
+ auto newType = VectorType::get(newShape, rewriter.getIndexType());
+ auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+ strideMultiplied[i]);
+ shapeCasted.push_back(castOp);
+ }
+
+ // Broadcast each shape-casted vector to full vector shape
+ SmallVector<Value> broadcasted;
+ auto fullIndexVectorType =
+ VectorType::get(vectorShape, rewriter.getIndexType());
+ for (Value shapeCastVal : shapeCasted) {
+ auto broadcastOp = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, shapeCastVal);
+ broadcasted.push_back(broadcastOp);
+ }
+
+ // Add all broadcasted vectors together to compute local offsets
+ Value localOffsets = broadcasted[0];
+ for (size_t i = 1; i < broadcasted.size(); ++i)
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
+
+ // Compute base offset from transfer read indices
+ Value baseOffset = nullptr;
+ if (!indices.empty()) {
+ baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ for (size_t i = 0; i < indices.size(); ++i) {
+ Value strideVal = strides[i];
+ Value offsetContrib =
+ arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+ baseOffset =
+ arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+ }
+ // Broadcast base offset to match vector shape
+ Value bcastBase = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, baseOffset);
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+ }
+ return localOffsets;
+}
+
+// Collapse memref shape to 1D
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ Location loc = xferOp.getLoc();
+
+ Value baseMemref = xferOp.getBase();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+ Type elementType = memrefType.getElementType();
+
+ // Compute the total number of elements in the memref
+ MemRefType flatMemrefType;
+ if (memrefType.hasStaticShape()) {
+ auto totalElements = memrefType.getNumElements();
+ flatMemrefType = MemRefType::get({totalElements}, elementType);
+ } else {
+ flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+ }
+
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices allDims =
+ llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
+ reassociation.push_back(allDims);
+
+ auto collapseOp = memref::CollapseShapeOp::create(
+ rewriter, loc, flatMemrefType, baseMemref, reassociation);
+ return collapseOp;
+}
+
+static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) {
+
+ Location loc = readOp.getLoc();
+ VectorType vectorType = readOp.getVectorType();
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+ auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+ SmallVector<Value> strides = computeStrides(readOp, rewriter);
+ if (strides.empty())
+ return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+ Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+ Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
+ Value mask = vector::ConstantMaskOp::create(
+ rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+ vectorShape);
+ auto gatherOp = xegpu::LoadGatherOp::create(
+ rewriter, loc, vectorType, flatMemref, localOffsets, mask,
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+ rewriter.replaceOp(readOp, gatherOp.getResult());
+ return success();
+}
+
+static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) {
+
+ Location loc = writeOp.getLoc();
+ VectorType vectorType = writeOp.getVectorType();
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  if (strides.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
+
+ Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+
+ Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+
+ Value mask = vector::ConstantMaskOp::create(
+ rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+ vectorShape);
+ xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+ localOffsets, mask,
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+ rewriter.eraseOp(writeOp);
+ return success();
+}
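Both scattered lowerings address the flattened buffer with per-element linear offsets: a base offset computed from the transfer indices plus a local offset built from the step vectors, each term multiplied by the matching (possibly permuted) stride. A minimal sketch of that linearization, with illustrative names and values only:

#include <cassert>
#include <cstdint>
#include <vector>

// Linear offset of one element: dot product of its indices with the strides.
static int64_t linearize(const std::vector<int64_t> &indices,
                         const std::vector<int64_t> &strides) {
  int64_t offset = 0;
  for (size_t i = 0; i < indices.size(); ++i)
    offset += indices[i] * strides[i];
  return offset;
}

int main() {
  // Element [1, 0, 2, 5] of a buffer with strides {384, 192, 32, 1}.
  assert(linearize({1, 0, 2, 5}, {384, 192, 32, 1}) == 453);
  return 0;
}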
+
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
if (failed(transferPreconditions(rewriter, readOp)))
return failure();
+    // TODO: Replace this check with a proper uArch capability check.
+ auto chip = xegpu::getChipStr(readOp);
+ if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered load op if the target HW doesn't have 2D block
+      // load support.
+      // TODO: Add support for out-of-bounds accesses.
+ if (readOp.hasOutOfBoundsDim())
+ return failure();
+ return lowerToScatteredLoadOp(readOp, rewriter);
+ }
+
+ // Perform common data transfer checks.
+ VectorType vecTy = readOp.getVectorType();
+ if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+ return failure();
+
bool isOutOfBounds = readOp.hasOutOfBoundsDim();
if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
return rewriter.notifyMatchFailure(
@@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
AffineMap readMap = readOp.getPermutationMap();
bool isTransposeLoad = !readMap.isMinorIdentity();
- VectorType vecTy = readOp.getVectorType();
Type elementType = vecTy.getElementType();
unsigned minTransposeBitWidth = 32;
if (isTransposeLoad &&
@@ -221,11 +505,26 @@ struct TransferWriteLowering
if (failed(transferPreconditions(rewriter, writeOp)))
return failure();
+    // TODO: Replace this check with a proper uArch capability check.
+ auto chip = xegpu::getChipStr(writeOp);
+ if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered store op if the target HW doesn't have 2D block
+      // store support.
+      // TODO: Add support for out-of-bounds accesses.
+ if (writeOp.hasOutOfBoundsDim())
+ return failure();
+ return lowerToScatteredStoreOp(writeOp, rewriter);
+ }
+
+ // Perform common data transfer checks.
+ VectorType vecTy = writeOp.getVectorType();
+ if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+ return failure();
+
AffineMap map = writeOp.getPermutationMap();
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
- VectorType vecTy = writeOp.getVectorType();
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..84b2580
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -0,0 +1,27 @@
+add_mlir_conversion_library(MLIRXeGPUToXeVM
+ XeGPUToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRGPUDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRVectorDialect
+ MLIRArithDialect
+ MLIRIndexDialect
+ MLIRSCFDialect
+ MLIRXeGPUDialect
+ MLIRPass
+ MLIRTransforms
+ MLIRSCFTransforms
+)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
new file mode 100644
index 0000000..a7f2dc2
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -0,0 +1,1026 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+// TODO: Below are uArch dependent values, should move away from hardcoding
+static constexpr int32_t systolicDepth{8};
+static constexpr int32_t executionSize{16};
+
+// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ TensorOffsetW = 4, // Tensor offset W (i32)
+ TensorOffsetH = 5 // Tensor offset H (i32)
+};
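As a concrete picture of the lane assignments named by this enum: the descriptor is an 8xi32 vector whose first two lanes hold the 64-bit base pointer (written through the 4xi64 bitcast view used below), followed by the base shape and tensor offsets. The following plain-C++ analogue only illustrates that layout; the values are arbitrary:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t payload[8] = {};  // 8xi32 nd-tensor-descriptor payload.
  uint64_t basePtr = 0x1000; // Arbitrary 64-bit base address.
  std::memcpy(&payload[0], &basePtr, sizeof(basePtr)); // BasePtr: lanes 0-1.
  payload[2] = 128; // BaseShapeW
  payload[3] = 64;  // BaseShapeH
  payload[4] = 16;  // TensorOffsetW
  payload[5] = 8;   // TensorOffsetH
  uint64_t readBack = 0;
  std::memcpy(&readBack, &payload[0], sizeof(readBack));
  assert(readBack == basePtr);
  return 0;
}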
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+ switch (xeGpuMemspace) {
+ case xegpu::MemorySpace::Global:
+ return static_cast<int>(xevm::AddrSpace::GLOBAL);
+ case xegpu::MemorySpace::SLM:
+ return static_cast<int>(xevm::AddrSpace::SHARED);
+  }
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+// Get same bitwidth flat vector type of new element type.
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+ Type toElemType) {
+ auto elemType = currentVecType.getElementType();
+ auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+ auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+ const int size =
+ currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+ return VectorType::get(size, toElemType);
+}
+
+static xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::CACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::READ_INVALIDATE:
+ return xevm::LoadCacheControl::INVALIDATE_READ;
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+static xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_BACK:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_THROUGH:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+class CreateNdDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op,
+ xegpu::CreateNdDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto source = op.getSource();
+    // The op is lowered to a code sequence that populates the payload.
+    // The payload is an 8xi32 vector. Offsets to individual fields are
+    // defined in the NdTdescOffset enum.
+ Type payloadElemTy = rewriter.getI32Type();
+ VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ Type i64Ty = rewriter.getI64Type();
+ // 4xi64 view is used for inserting the base pointer.
+ VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+ // Initialize payload to zero.
+ Value payload = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+ Value baseAddr;
+ Value baseShapeW;
+ Value baseShapeH;
+ Value offsetW;
+ Value offsetH;
+
+ // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ // Descriptor shape is expected to be 2D.
+ int64_t rank = mixedSizes.size();
+ if (rank != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+ auto sourceTy = source.getType();
+ auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+ // If source is a memref, we need to extract the aligned pointer as index.
+ // Pointer type is passed as i32 or i64 by type converter.
+ if (sourceMemrefTy) {
+ if (!sourceMemrefTy.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+ }
+ baseAddr =
+ memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ } else {
+ baseAddr = adaptor.getSource();
+ }
+ // Utility for creating offset values from op fold result.
+ auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
+ unsigned idx) -> Value {
+ Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+ val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
+ return val;
+ };
+ // Offsets can be either 2D or not provided (0 is used).
+ if (mixedOffsets.size() == 2) {
+ offsetW = createOffset(mixedOffsets, 1);
+ offsetH = createOffset(mixedOffsets, 0);
+ } else if (mixedOffsets.size() == 0) {
+ offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ } else {
+ return rewriter.notifyMatchFailure(op,
+ "Expected 2D offsets or no offsets.");
+ }
+ // Get shape values from op fold results.
+ baseShapeW = createOffset(mixedSizes, 1);
+ baseShapeH = createOffset(mixedSizes, 0);
+ if (sourceMemrefTy) {
+ // Cast index to i64.
+ baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ } else if (baseAddr.getType() != i64Ty) {
+ // Pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ }
+ // Populate payload.
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+ payLoadAsI64 =
+ vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+ static_cast<int>(NdTdescOffset::BasePtr));
+ payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+ static_cast<int>(NdTdescOffset::BaseShapeW));
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+ static_cast<int>(NdTdescOffset::BaseShapeH));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetW, payload,
+ static_cast<int>(NdTdescOffset::TensorOffsetW));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetH, payload,
+ static_cast<int>(NdTdescOffset::TensorOffsetH));
+ rewriter.replaceOp(op, payload);
+ return success();
+ }
+};
+
+class UpdateNdOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+ xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto mixedOffsets = op.getMixedOffsets();
+ // Only 2D offsets are supported for now.
+ if (mixedOffsets.size() != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+ auto payload = adaptor.getTensorDesc();
+ // Utility for updating payload offset values from op fold result.
+ auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offset);
+ Value oldOffset =
+ vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
+ Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+ return vector::InsertOp::create(rewriter, loc, newOffset, payload,
+ payloadPos);
+ };
+ // Update offsets in the payload.
+ payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ rewriter.replaceOp(op, payload);
+ return success();
+ }
+};
+
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+
+ auto tdesc = adaptor.getTensorDesc();
+ auto tdescTy = op.getTensorDescType();
+ if (tdescTy.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto elemType = tdescTy.getElementType();
+ auto elemBitSize = elemType.getIntOrFloatBitWidth();
+ if (elemBitSize % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr = vector::ExtractOp::create(
+ rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ // Offsets are provided in one of two ways:
+ // 1. extracted from the tensor descriptor, or
+ // 2. as (mixed) offsets supplied directly on the op.
+ Value offsetW;
+ Value offsetH;
+ auto mixedOffsets = op.getMixedOffsets();
+ int64_t opOffsetsSize = mixedOffsets.size();
+ if (opOffsetsSize != 0 && opOffsetsSize != 2)
+ return rewriter.notifyMatchFailure(op,
+ "Expected 2D offsets or no offsets.");
+ if (opOffsetsSize) {
+ // If mixed offsets are provided by the op, convert them to i32.
+ offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
+ } else {
+ // If offsets are not available, we need to extract them from the tensor
+ // descriptor.
+ offsetW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ offsetH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ }
+ // Get address space from tensor descriptor memory space.
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // Compute element byte size and surface width in bytes.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ Value surfaceW =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+ // Get tile sizes and vblocks from the tensor descriptor type.
+ auto tileW = tdescTy.getDimSize(1);
+ auto tileH = tdescTy.getDimSize(0);
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If the store value is a scalar, take it from the op instead of the
+ // adaptor; the adaptor may have folded away a single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+ offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+ offsetH, elemBitSize, tileW, tileH, vblocks,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
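+ // A transpose permutation of [1, 0] is detected by checking only the
+ // leading dimension of the permutation.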
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+ surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ return success();
+ }
+};
+
+// Helper that emits IR computing: baseAddr + offset * elemByteSize.
+static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
+ Value baseAddr, Value offset, int64_t elemByteSize) {
+ Value byteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
+ Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
+ return newAddr;
+}
+
+class CreateDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto eTy = op.getTensorDescType().getElementType();
+ auto eBw = eTy.getIntOrFloatBitWidth();
+ if (eBw % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ auto loc = op.getLoc();
+ // Offsets are provided as a scalar i64 by the type converter.
+ auto offsets = adaptor.getOffsets();
+ // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
+ // The type converter lowers all of these to integer types.
+ Value addr = adaptor.getSource();
+ // ui32 or i32 sources arrive as i32 and need to be cast to i64.
+ if (addr.getType() != rewriter.getI64Type())
+ addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
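+ // A scattered descriptor lowers to a plain per-lane address: the base
+ // address plus the lane offset scaled by the element byte size.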
+ auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
+ rewriter.replaceOp(op, laneAddr);
+ return success();
+ }
+};
+
+class UpdateOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateOffsetOp op,
+ xegpu::UpdateOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto eTy = op.getTensorDescType().getElementType();
+ auto eBw = eTy.getIntOrFloatBitWidth();
+ if (eBw % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ auto loc = op.getLoc();
+ // The scatter descriptor and the offsets are both provided as scalar i64
+ // values by the type converter.
+ Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+ adaptor.getOffsets(), eBw / 8);
+ rewriter.replaceOp(op, newOffset);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdescTy = op.getTensorDescType();
+ Value basePtrI64;
+ // The load result or store value type can be a vector or a scalar.
+ Type valOrResTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
+ valOrResTy = op.getResult().getType();
+ else
+ valOrResTy = adaptor.getValue().getType();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+ bool hasScalarVal = !valOrResVecTy;
+ int64_t elemBitWidth =
+ hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+ : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+ // Default memory space is global.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ // If tensor descriptor is available, we use its memory space.
+ if (tdescTy)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // Base pointer can come from source (load) or dest (store).
+ // If they are memrefs, we use their memory space.
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ basePtrI64 = adaptor.getSource();
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ } else {
+ basePtrI64 = adaptor.getDest();
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ }
+ // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
+ // needed.
+ if (basePtrI64.getType() != rewriter.getI64Type()) {
+ basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
+ }
+ Value offsets = adaptor.getOffsets();
+ Value mask = adaptor.getMask();
+ if (offsets) {
+ if (isa<VectorType>(offsets.getType())) {
+ // Offsets need to be scalar; a single-element vector is converted to a
+ // scalar by the type converter.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ // If offsets are provided, add them to the base pointer. Offsets are
+ // expressed in number of elements, so multiply by the element byte size.
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ }
+ }
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+ Value maskForLane;
+ VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
+ if (maskVecTy) {
+ // The mask needs to be a scalar; a single-element vector is converted to
+ // a scalar by the type converter.
+ return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
+ } else
+ maskForLane = mask;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
+ maskForLane, true, true);
+ // If the mask is true (then clause), load from memory and yield.
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
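+ // Loads are performed on a flat 1D vector type regardless of the original
+ // result vector's shape.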
+ if (!hasScalarVal)
+ valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
+ valOrResVecTy.getElementType());
+ Value loaded =
+ LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
+ // Set cache control attribute on the load operation.
+ loaded.getDefiningOp()->setAttr(
+ "cache_control", xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
+ scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // If the mask is false (else clause), yield a zero constant of the result
+ // type.
+ auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
+ TypedAttr eVal;
+ if (eTy.isFloat())
+ eVal = FloatAttr::get(eTy, 0.0);
+ else
+ eVal = IntegerAttr::get(eTy, 0);
+ if (hasScalarVal)
+ loaded = arith::ConstantOp::create(rewriter, loc, eVal);
+ else
+ loaded = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
+ scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+ rewriter.replaceOp(op, ifOp.getResult(0));
+ } else {
+ // If mask is true, perform the store.
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
+ auto body = ifOp.getBody();
+ rewriter.setInsertionPointToStart(body);
+ auto storeOp =
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
+ // Set cache control attribute on the store operation.
+ storeOp.getOperation()->setAttr(
+ "cache_control", xevm::StoreCacheControlAttr::get(
+ ctxt, translateStoreXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
+class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdescTy = op.getTensorDescType();
+ Value basePtrI64 = adaptor.getSource();
+ // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
+ // needed.
+ if (basePtrI64.getType() != rewriter.getI64Type())
+ basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
+ Value offsets = adaptor.getOffsets();
+ if (offsets) {
+ VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
+ if (offsetsVecTy) {
+ // Offsets need to be scalar.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ int64_t elemBitWidth{0};
+ int64_t elemByteSize;
+ // Element byte size can come from three sources:
+ if (tdescTy) {
+ // If tensor descriptor is available, we use its element type to
+ // determine element byte size.
+ elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+ } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
+ // If memref is available, we use its element type to
+ // determine element byte size.
+ elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
+ } else {
+ // Otherwise, we use the provided offset byte alignment.
+ elemByteSize = *op.getOffsetAlignByte();
+ }
+ if (elemBitWidth != 0) {
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ elemByteSize = elemBitWidth / 8;
+ }
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ }
+ }
+ // Default memory space is global.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ // If tensor descriptor is available, we use its memory space.
+ if (tdescTy)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // If source is a memref, we use its memory space.
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value ptrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ // Create the prefetch op with cache control attribute.
+ xevm::PrefetchOp::create(
+ rewriter, loc, ptrLLVM,
+ xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
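+ // Map the XeGPU fence scope to an XeVM memory scope and the XeGPU memory
+ // kind to an XeVM address space, then emit a single memfence.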
+ xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
+ switch (op.getFenceScope()) {
+ case xegpu::FenceScope::Workgroup:
+ memScope = xevm::MemScope::WORKGROUP;
+ break;
+ case xegpu::FenceScope::GPU:
+ memScope = xevm::MemScope::DEVICE;
+ break;
+ }
+ xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
+ switch (op.getMemoryKind()) {
+ case xegpu::MemorySpace::Global:
+ addrSpace = xevm::AddrSpace::GLOBAL;
+ break;
+ case xegpu::MemorySpace::SLM:
+ addrSpace = xevm::AddrSpace::SHARED;
+ break;
+ }
+ xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto aTy = cast<VectorType>(op.getLhs().getType());
+ auto bTy = cast<VectorType>(op.getRhs().getType());
+ auto resultType = cast<VectorType>(op.getResultType());
+
+ auto encodePrecision = [&](Type type) -> xevm::ElemType {
+ if (type == rewriter.getBF16Type())
+ return xevm::ElemType::BF16;
+ else if (type == rewriter.getF16Type())
+ return xevm::ElemType::F16;
+ else if (type == rewriter.getTF32Type())
+ return xevm::ElemType::TF32;
+ else if (type.isInteger(8)) {
+ if (type.isUnsignedInteger())
+ return xevm::ElemType::U8;
+ return xevm::ElemType::S8;
+ } else if (type == rewriter.getF32Type())
+ return xevm::ElemType::F32;
+ else if (type.isInteger(32))
+ return xevm::ElemType::S32;
+ llvm_unreachable("add more support for ElemType");
+ };
+ xevm::ElemType precATy = encodePrecision(aTy.getElementType());
+ xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
+ Value c = op.getAcc();
+ if (!c) {
+ auto elementTy = resultType.getElementType();
+ Attribute initValueAttr;
+ if (isa<FloatType>(elementTy))
+ initValueAttr = FloatAttr::get(elementTy, 0.0);
+ else
+ initValueAttr = IntegerAttr::get(elementTy, 0);
+ c = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
+ }
+
+ Value aVec = op.getLhs();
+ Value bVec = op.getRhs();
+ auto cvecty = cast<VectorType>(c.getType());
+ xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
+ xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
+ VectorType cNty =
+ VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
+ if (cvecty != cNty)
+ c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
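+ // The MMA shape is built from the accumulator element count, the execution
+ // (SIMD) size, and the systolic depth scaled by the number of A operands
+ // packed per 32-bit dword.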
+ Value dpasRes = xevm::MMAOp::create(
+ rewriter, loc, cNty, aVec, bVec, c,
+ xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
+ systolicDepth *
+ getNumOperandsPerDword(precATy)),
+ xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
+ if (cvecty != cNty)
+ dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
+ rewriter.replaceOp(op, dpasRes);
+ return success();
+ }
+
+private:
+ static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
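+ // A 32-bit dword packs 32 / elementBitWidth operands of the given type.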
+ switch (pTy) {
+ case xevm::ElemType::TF32:
+ return 1;
+ case xevm::ElemType::BF16:
+ case xevm::ElemType::F16:
+ return 2;
+ case xevm::ElemType::U8:
+ case xevm::ElemType::S8:
+ return 4;
+ default:
+ llvm_unreachable("unsupported xevm::ElemType");
+ }
+ }
+};
+
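+// Maps an arith::AtomicRMWKind to the corresponding LLVM atomic binary op;
+// returns std::nullopt for kinds without a direct LLVM equivalent.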
+static std::optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
+ switch (arithKind) {
+ case arith::AtomicRMWKind::addf:
+ return LLVM::AtomicBinOp::fadd;
+ case arith::AtomicRMWKind::addi:
+ return LLVM::AtomicBinOp::add;
+ case arith::AtomicRMWKind::assign:
+ return LLVM::AtomicBinOp::xchg;
+ case arith::AtomicRMWKind::maximumf:
+ return LLVM::AtomicBinOp::fmax;
+ case arith::AtomicRMWKind::maxs:
+ return LLVM::AtomicBinOp::max;
+ case arith::AtomicRMWKind::maxu:
+ return LLVM::AtomicBinOp::umax;
+ case arith::AtomicRMWKind::minimumf:
+ return LLVM::AtomicBinOp::fmin;
+ case arith::AtomicRMWKind::mins:
+ return LLVM::AtomicBinOp::min;
+ case arith::AtomicRMWKind::minu:
+ return LLVM::AtomicBinOp::umin;
+ case arith::AtomicRMWKind::ori:
+ return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::andi:
+ return LLVM::AtomicBinOp::_and;
+ default:
+ return std::nullopt;
+ }
+}
+
+class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdesc = op.getTensorDesc().getType();
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
+ Value basePtrI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
+ VectorType srcOrDstFlatVecTy = VectorType::get(
+ srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+ Value srcFlatVec = vector::ShapeCastOp::create(
+ rewriter, loc, srcOrDstFlatVecTy, op.getValue());
+ auto atomicKind = matchSimpleAtomicOp(op.getKind());
+ assert(atomicKind.has_value());
+ Value resVec = srcFlatVec;
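+ // Lower the vector atomic to one scalar llvm.atomicrmw per element,
+ // addressing each element through a GEP off the base pointer and inserting
+ // each result back into the result vector.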
+ for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
+ auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
+ Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(i));
+ Value currPtr =
+ LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
+ srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
+ Value newVal =
+ LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
+ val, LLVM::AtomicOrdering::seq_cst);
+ resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
+ }
+ rewriter.replaceOp(op, resVec);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct ConvertXeGPUToXeVMPass
+ : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
+ using Base::Base;
+
+ void runOnOperation() override {
+ LLVMTypeConverter typeConverter(&getContext());
+ typeConverter.addConversion([&](VectorType type) -> Type {
+ unsigned rank = type.getRank();
+ auto elemType = type.getElementType();
+ // If the element type is index, convert it to i64.
+ if (llvm::isa<IndexType>(elemType))
+ elemType = IntegerType::get(&getContext(), 64);
+ // If the vector is rank-0 or has a single element, return just the element
+ // type.
+ if (rank < 1 || type.getNumElements() == 1)
+ return elemType;
+ // Otherwise, convert the vector to a flat vector type.
+ int64_t numElements =
+ std::accumulate(type.getShape().begin(), type.getShape().end(),
+ int64_t{1}, std::multiplies<int64_t>());
+ return VectorType::get(numElements, elemType);
+ });
+ typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ if (type.isScattered())
+ return IntegerType::get(&getContext(), 64);
+ auto i32Type = IntegerType::get(&getContext(), 32);
+ return VectorType::get(8, i32Type);
+ });
+ typeConverter.addConversion([&](MemRefType type) -> Type {
+ // Convert MemRefType to i64 type.
+ return IntegerType::get(&getContext(), 64);
+ });
+
+ // The LLVM type converter inserts unrealized casts for the cases below;
+ // register materialization casts to resolve them.
+
+ // Materialization to convert memref to i64
+ auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+
+ Value addr =
+ memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
+ return arith::IndexCastUIOp::create(builder, loc, type, addr)
+ .getResult();
+ }
+ return {};
+ };
+
+ // Materialization to convert ui64 to i64
+ auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType() == builder.getIntegerType(64, false)) {
+ Value cast =
+ index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ }
+ return {};
+ };
+
+ // Materialization to convert ui32 to i32
+ auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType() == builder.getIntegerType(32, false)) {
+ Value cast =
+ index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ }
+ return {};
+ };
+
+ // Materialization to convert a vector by either:
+ // - extracting the scalar from a single-element vector,
+ // - bitcasting to a vector of the same rank, or
+ // - shape-casting to a vector of a different rank but same element type.
+ auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+ if (vecTy.getNumElements() == 1) {
+ // If the vector has a single element, extract and return that element.
+ Value cast =
+ vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ if (vecTy.getElementType() == builder.getIndexType())
+ cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ return cast;
+ } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+ // If the target type is a vector of same rank,
+ // bitcast to the target type.
+ if (targetVecTy.getRank() == vecTy.getRank())
+ return vector::BitCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ else if (targetVecTy.getElementType() == vecTy.getElementType()) {
+ // If the target type is a vector of different rank but same element
+ // type, reshape to the target type.
+ return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(memrefMaterializationCast);
+ typeConverter.addSourceMaterialization(ui64MaterializationCast);
+ typeConverter.addSourceMaterialization(ui32MaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
+ typeConverter.addTargetMaterialization(memrefMaterializationCast);
+ typeConverter.addTargetMaterialization(ui32MaterializationCast);
+ typeConverter.addTargetMaterialization(ui64MaterializationCast);
+ typeConverter.addTargetMaterialization(vectorMaterializationCast);
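+ // The casts are registered as both source and target materializations so
+ // that values crossing the converted/unconverted boundary in either
+ // direction can be bridged.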
+ ConversionTarget target(getContext());
+ target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
+ vector::VectorDialect, arith::ArithDialect,
+ memref::MemRefDialect, gpu::GPUDialect,
+ index::IndexDialect>();
+ target.addIllegalDialect<xegpu::XeGPUDialect>();
+
+ RewritePatternSet patterns(&getContext());
+ populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
+ scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+void mlir::populateXeGPUToXeVMConversionPatterns(
+ const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
+ typeConverter, patterns.getContext());
+ patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
+ AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+ LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
+ LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
+ typeConverter, patterns.getContext());
+ patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
+ patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d7ffdcb..11a40d6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -511,6 +511,18 @@ LogicalResult DPPOp::verify() {
}
//===----------------------------------------------------------------------===//
+// PermlaneSwapOp
+//===----------------------------------------------------------------------===//
+LogicalResult PermlaneSwapOp::verify() {
+ unsigned rowLength = getRowLength();
+
+ if (rowLength != 16 && rowLength != 32)
+ return emitOpError("row_length attribute must either be 16 or 32.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 729e3da..d35853b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
ResolveStridedMetadata.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
DEPENDS
MLIRAMDGPUTransformsIncGen
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110c..68990ef 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 86edc2b..b405ec2 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
int64_t lb = forOp.getConstantLowerBound();
dividend[pos] = 1;
dividend.back() -= lb;
- addLocalFloorDiv(dividend, step);
+ unsigned qPos = addLocalFloorDiv(dividend, step);
// Second constraint: (iv - lb) - step * q = 0.
SmallVector<int64_t, 8> eq(getNumCols(), 0);
eq[pos] = 1;
eq.back() -= lb;
// For the local var just added above.
- eq[getNumCols() - 2] = -step;
+ eq[qPos] = -step;
addEquality(eq);
}
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 2f85e0b..166d39e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -21,6 +21,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <numeric>
#include <optional>
@@ -548,19 +549,19 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
// Check whether there is any negative direction vector in the
// dependence components found above, which means that dependence is
// violated by the default hyper-rect tiling method.
- LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
- "for dependence at depth: "
- << Twine(d) << " between:\n";);
- LLVM_DEBUG(srcAccess.opInst->dump());
- LLVM_DEBUG(dstAccess.opInst->dump());
+ LDBG() << "Checking whether tiling legality violated "
+ << "for dependence at depth: " << Twine(d) << " between:"
+ << OpWithFlags(srcAccess.opInst, OpPrintingFlags().skipRegions())
+ << "\nand:\n"
+ << OpWithFlags(dstAccess.opInst,
+ OpPrintingFlags().skipRegions());
for (const DependenceComponent &depComp : depComps) {
if (depComp.lb.has_value() && depComp.ub.has_value() &&
*depComp.lb < *depComp.ub && *depComp.ub < 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Dependence component lb = " << Twine(*depComp.lb)
- << " ub = " << Twine(*depComp.ub)
- << " is negative at depth: " << Twine(d)
- << " and thus violates the legality rule.\n");
+ LDBG() << "Dependence component lb = " << Twine(*depComp.lb)
+ << " ub = " << Twine(*depComp.ub)
+ << " is negative at depth: " << Twine(d)
+ << " and thus violates the legality rule.";
return false;
}
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index a89c1ae..99ea20b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -241,7 +242,7 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
}
bool MemRefDependenceGraph::init() {
- LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+ LDBG() << "--- Initializing MDG ---";
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
@@ -288,8 +289,7 @@ bool MemRefDependenceGraph::init() {
// Return false if non-handled/unknown region-holding ops are found. We
// won't know what such ops do or what its regions mean; for e.g., it may
// not be an imperative op.
- LLVM_DEBUG(llvm::dbgs()
- << "MDG init failed; unknown region-holding op found!\n");
+ LDBG() << "MDG init failed; unknown region-holding op found!";
return false;
}
// We aren't creating nodes for memory-effect free ops either with no
@@ -297,7 +297,7 @@ bool MemRefDependenceGraph::init() {
// interface.
}
- LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");
+ LDBG() << "Created " << nodes.size() << " nodes";
// Add dependence edges between nodes which produce SSA values and their
// users. Load ops can be considered as the ones producing SSA values.
@@ -556,9 +556,8 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
gatherDefiningNodes(dstId, definingNodes);
if (llvm::any_of(definingNodes,
[&](unsigned id) { return hasDependencePath(srcId, id); })) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: a defining op with a user in the dst "
- "loop has dependence from the src loop\n");
+ LDBG() << "Can't fuse: a defining op with a user in the dst "
+ << "loop has dependence from the src loop";
return nullptr;
}
@@ -957,20 +956,20 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
FlatAffineValueConstraints srcConstraints;
// TODO: Store the source's domain to avoid computation at each depth.
if (failed(getSourceAsConstraints(srcConstraints))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
+ LDBG() << "Unable to compute source's domain";
return std::nullopt;
}
// As the set difference utility currently cannot handle symbols in its
// operands, validity of the slice cannot be determined.
if (srcConstraints.getNumSymbolVars() > 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
+ LDBG() << "Cannot handle symbols in source domain";
return std::nullopt;
}
// TODO: Handle local vars in the source domains while using the 'projectOut'
// utility below. Currently, aligning is not done assuming that there will be
// no local vars in the source domain.
if (srcConstraints.getNumLocalVars() != 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
+ LDBG() << "Cannot handle locals in source domain";
return std::nullopt;
}
@@ -978,7 +977,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
// fusion succeeds.
FlatAffineValueConstraints sliceConstraints;
if (failed(getAsConstraints(&sliceConstraints))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
+ LDBG() << "Unable to compute slice's domain";
return std::nullopt;
}
@@ -987,11 +986,11 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
sliceConstraints.projectOut(ivs.size(),
sliceConstraints.getNumVars() - ivs.size());
- LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
- LLVM_DEBUG(srcConstraints.dump());
- LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
- "(expressed in terms of its source's IVs):\n");
- LLVM_DEBUG(sliceConstraints.dump());
+ LDBG() << "Domain of the source of the slice:\n"
+ << "Source constraints:" << srcConstraints
+ << "\nDomain of the slice if this fusion succeeds "
+ << "(expressed in terms of its source's IVs):\n"
+ << "Slice constraints:" << sliceConstraints;
// TODO: Store 'srcSet' to avoid recalculating for each depth.
PresburgerSet srcSet(srcConstraints);
@@ -999,7 +998,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
PresburgerSet diffSet = sliceSet.subtract(srcSet);
if (!diffSet.isIntegerEmpty()) {
- LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
+ LDBG() << "Incorrect slice";
return false;
}
return true;
@@ -1172,8 +1171,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
unsigned rank = access.getRank();
- LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
- << "\ndepth: " << loopDepth << "\n";);
+ LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;
// 0-d memrefs.
if (rank == 0) {
@@ -1236,7 +1234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
if (auto constVal = getConstantIntValue(symbol))
cst.addBound(BoundType::EQ, symbol, constVal.value());
} else {
- LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
+ LDBG() << "unknown affine dimensional value";
return failure();
}
}
@@ -1260,7 +1258,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
// Add access function equalities to connect loop IVs to data dimensions.
if (failed(cst.composeMap(&accessValueMap))) {
op->emitError("getMemRefRegion: compose affine map failed");
- LLVM_DEBUG(accessValueMap.getAffineMap().dump());
+ LDBG() << "Access map: " << accessValueMap.getAffineMap();
return failure();
}
@@ -1317,8 +1315,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
}
cst.removeTrivialRedundancy();
- LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
- LLVM_DEBUG(cst.dump());
+ LDBG() << "Memory region: " << cst;
return success();
}
@@ -1346,14 +1343,14 @@ std::optional<int64_t> MemRefRegion::getRegionSize() {
auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ LDBG() << "Non-identity layout map not yet supported";
return false;
}
// Compute the extents of the buffer.
std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
if (!numElements) {
- LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+ LDBG() << "Dynamic shapes not yet supported";
return std::nullopt;
}
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
@@ -1397,8 +1394,7 @@ LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
/*addMemRefDimBounds=*/false)))
return success();
- LLVM_DEBUG(llvm::dbgs() << "Memory region");
- LLVM_DEBUG(region.getConstraints()->dump());
+ LDBG() << "Memory region: " << region.getConstraints();
bool outOfBounds = false;
unsigned rank = loadOrStoreOp.getMemRefType().getRank();
@@ -1558,7 +1554,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Check if 'loopDepth' exceeds nesting depth of src/dst ops.
if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
(isBackwardSlice && loopDepth > getNestingDepth(b))) {
- LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
+ LDBG() << "Invalid loop depth";
return SliceComputationResult::GenericFailure;
}
@@ -1571,7 +1567,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
&dependenceConstraints, /*dependenceComponents=*/nullptr,
/*allowRAR=*/readReadAccesses);
if (result.value == DependenceResult::Failure) {
- LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
+ LDBG() << "Dependence check failed";
return SliceComputationResult::GenericFailure;
}
if (result.value == DependenceResult::NoDependence)
@@ -1586,8 +1582,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
// Initialize 'sliceUnionCst' with the bounds computed in previous step.
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n");
+ LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
@@ -1597,8 +1592,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
FlatAffineValueConstraints tmpSliceCst;
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n");
+ LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
@@ -1630,8 +1624,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumLocalVars() > 0 ||
tmpSliceCst.getNumLocalVars() > 0 ||
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute union bounding box of slice bounds\n");
+ LDBG() << "Unable to compute union bounding box of slice bounds";
return SliceComputationResult::GenericFailure;
}
}
@@ -1639,7 +1632,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Empty union.
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
- LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n");
+ LDBG() << "empty slice union - unexpected";
return SliceComputationResult::GenericFailure;
}
@@ -1652,7 +1645,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
unsigned innermostCommonLoopDepth =
getInnermostCommonLoopDepth(ops, &surroundingLoops);
if (loopDepth > innermostCommonLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
+ LDBG() << "Exceeds max loop depth";
return SliceComputationResult::GenericFailure;
}
@@ -1696,7 +1689,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// that the slice is valid, otherwise return appropriate failure status.
std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
if (!isSliceValid) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
+ LDBG() << "Cannot determine if the slice is valid";
return SliceComputationResult::GenericFailure;
}
if (!*isSliceValid)
@@ -2050,7 +2043,8 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (failed(
region->compute(opInst,
/*loopDepth=*/getNestingDepth(&*block.begin())))) {
- LLVM_DEBUG(opInst->emitError("error obtaining memory region"));
+ LDBG() << "Error obtaining memory region";
+ opInst->emitError("error obtaining memory region");
return failure();
}
@@ -2058,9 +2052,11 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (inserted) {
it->second = std::move(region);
} else if (failed(it->second->unionBoundingBox(*region))) {
- LLVM_DEBUG(opInst->emitWarning(
+ LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
+ "memory region";
+ opInst->emitWarning(
"getMemoryFootprintBytes: unable to perform a union on a memory "
- "region"));
+ "region");
return failure();
}
return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 22608a1..7e5ce26 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -427,6 +427,21 @@ bool mlir::affine::isValidSymbol(Value value) {
return false;
}
+/// A utility function to check if a value is defined at the top level of
+/// `region` or is an argument of `region` or is defined above the region.
+static bool isTopLevelValueOrAbove(Value value, Region *region) {
+ Region *parentRegion = value.getParentRegion();
+ do {
+ if (parentRegion == region)
+ return true;
+ Operation *regionOp = region->getParentOp();
+ if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ break;
+ region = region->getParentOp()->getParentRegion();
+ } while (region);
+ return false;
+}
+
/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
@@ -445,19 +460,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
return false;
// A top-level value is a valid symbol.
- if (region && ::isTopLevelValue(value, region))
+ if (region && isTopLevelValueOrAbove(value, region))
return true;
auto *defOp = value.getDefiningOp();
- if (!defOp) {
- // A block argument that is not a top-level value is a valid symbol if it
- // dominates region's parent op.
- Operation *regionOp = region ? region->getParentOp() : nullptr;
- if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
- if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
- return isValidSymbol(value, parentOpRegion);
+ if (!defOp)
return false;
- }
// Constant operation is ok.
Attribute operandCst;
@@ -475,12 +483,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
return isDimOpValidSymbol(dimOp, region);
- // Check for values dominating `region`'s parent op.
- Operation *regionOp = region ? region->getParentOp() : nullptr;
- if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
- if (auto *parentRegion = region->getParentOp()->getParentRegion())
- return isValidSymbol(value, parentRegion);
-
return false;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6c9adff..ff0157e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
@@ -95,8 +96,8 @@ static bool canRemoveSrcNodeAfterFusion(
// Otherwise, the src loop can't be removed.
if (fusedLoopInsPoint != depNodeOp &&
!fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
- "dominate dependence\n");
+ LDBG() << "Src loop can't be removed: dst loop doesn't "
+ << "dominate dependence";
return false;
}
@@ -109,14 +110,13 @@ static bool canRemoveSrcNodeAfterFusion(
if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
std::optional<bool> isMaximal = fusionSlice.isMaximal();
if (!isMaximal) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
- "if fusion is maximal\n");
+ LDBG() << "Src loop can't be removed: can't determine "
+ << "if fusion is maximal";
return false;
}
if (!*isMaximal) {
- LLVM_DEBUG(llvm::dbgs()
- << "Src loop can't be removed: fusion is not maximal\n");
+ LDBG() << "Src loop can't be removed: fusion is not maximal";
return false;
}
}
@@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) {
// Check if this is defined to be an alias of another memref.
if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
- if (isEscapingMemref(viewOp.getViewSource(), block))
+ if (memref == viewOp.getViewDest() &&
+ isEscapingMemref(viewOp.getViewSource(), block))
return true;
// Any op besides allocating ops wouldn't guarantee alias freedom
@@ -279,19 +280,19 @@ static std::optional<double> getAdditionalComputeFraction(
AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
int64_t &fusedLoopNestComputeCost) {
- LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+ LDBG() << "Determining additional compute fraction...";
// Compute cost of sliced and unsliced src loop nest.
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+ LDBG() << "Failed to get source loop nest stats.";
return std::nullopt;
}
// Compute cost of dst loop nest.
LoopNestStats dstLoopNestStats;
if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+ LDBG() << "Failed to get destination loop nest stats.";
return std::nullopt;
}
@@ -304,14 +305,14 @@ static std::optional<double> getAdditionalComputeFraction(
const ComputationSliceState &slice = depthSliceUnions[depth - 1];
// Skip slice union if it wasn't computed for this depth.
if (slice.isEmpty()) {
- LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+ LDBG() << "Slice wasn't computed.";
return std::nullopt;
}
if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
dstLoopNestStats, slice,
&fusedLoopNestComputeCost)) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+ LDBG() << "Unable to compute fusion compute cost";
return std::nullopt;
}
@@ -348,9 +349,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
MemRefAccess bM(cast<AffineWriteOpInterface>(b));
return aM == bM;
})) {
- LLVM_DEBUG(llvm::dbgs()
- << "Private memref creation unsupported for multiple producer "
- "stores with different access functions.\n");
+ LDBG() << "Private memref creation unsupported for multiple producer "
+ << "stores with different access functions.";
return nullptr;
}
@@ -455,8 +455,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
- LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
- << '\n');
+ LDBG() << "Created private memref of type: " << newMemRefType;
return newMemRef;
}
@@ -505,15 +504,12 @@ static bool isFusionProfitable(AffineForOp srcForOp,
unsigned maxLegalFusionDepth,
unsigned *dstLoopDepth,
double computeToleranceThreshold) {
- LLVM_DEBUG({
- llvm::dbgs()
- << "Checking whether fusion is profitable between source nest:\n";
- llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
- llvm::dbgs() << dstForOp << "\n";
- });
+ LDBG() << "Checking whether fusion is profitable between source nest:";
+ LDBG() << ' ' << srcForOp << " and destination nest:";
+ LDBG() << dstForOp;
if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
+ LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
return false;
}
@@ -537,8 +533,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// TODO: Suppport multiple producer stores in profitability
// analysis.
if (producerStores.size() > 1) {
- LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
- "supported for multiple producer store case.\n");
+ LDBG() << "Limited profitability analysis. Not "
+ << "supported for multiple producer store case.";
int64_t sliceCost;
int64_t fusedLoopNestComputeCost;
// We will still fuse if fusion obeys the specified compute
@@ -547,12 +543,11 @@ static bool isFusionProfitable(AffineForOp srcForOp,
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
fusedLoopNestComputeCost);
if (!fraction || fraction > computeToleranceThreshold) {
- LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
- "compute tolerance. Not fusing.\n");
+ LDBG() << "Additional computation exceeds "
+ << "compute tolerance. Not fusing.";
return false;
}
- LLVM_DEBUG(llvm::dbgs()
- << "Considering fusion profitable at max legal depth.\n");
+ LDBG() << "Considering fusion profitable at max legal depth.";
return true;
}
@@ -574,8 +569,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute MemRefRegion for source operation\n");
+ LDBG() << "Unable to compute MemRefRegion for source operation";
return false;
}
@@ -609,8 +603,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!mayAdditionalComputeFraction) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't determine additional compute fraction.\n");
+ LDBG() << "Can't determine additional compute fraction.";
continue;
}
double additionalComputeFraction = *mayAdditionalComputeFraction;
@@ -620,9 +613,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to compute slice write region at loopDepth: " << i
- << "\n");
+ LDBG() << "Failed to compute slice write region at loopDepth: " << i;
continue;
}
@@ -630,9 +621,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
sliceWriteRegion.getRegionSize();
if (!maybeSliceWriteRegionSizeBytes.has_value() ||
*maybeSliceWriteRegionSizeBytes == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to get slice write region size at loopDepth: " << i
- << "\n");
+ LDBG() << "Failed to get slice write region size at loopDepth: " << i;
continue;
}
int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
@@ -649,9 +638,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
<< " storage reduction factor: " << storageReduction << "x\n"
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
- << " slice write region size: " << sliceWriteRegionSizeBytes
- << "\n";
- llvm::dbgs() << msg.str();
+ << " slice write region size: " << sliceWriteRegionSizeBytes;
+ LDBG() << msg.str();
});
// TODO: This is a placeholder cost model.
@@ -670,28 +658,24 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// A simple cost model: fuse if it reduces the memory footprint.
if (!bestDstLoopDepth) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "All fusion choices involve more than the threshold amount of "
- "redundant computation; NOT fusing.\n");
+ LDBG() << "All fusion choices involve more than the threshold amount of "
+ << "redundant computation; NOT fusing.";
return false;
}
if (!bestDstLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+ LDBG() << "no fusion depth could be evaluated.";
return false;
}
// Set dstLoopDepth based on best values from search.
*dstLoopDepth = *bestDstLoopDepth;
- LLVM_DEBUG(
- llvm::dbgs() << " LoopFusion fusion stats:"
- << "\n best loop depth: " << bestDstLoopDepth
- << "\n src loop nest compute cost: " << srcLoopNestCost
- << "\n dst loop nest compute cost: " << dstLoopNestCost
- << "\n fused loop nest compute cost: "
- << minFusedLoopNestComputeCost << "\n");
+ LDBG() << " LoopFusion fusion stats:";
+ LDBG() << " best loop depth: " << bestDstLoopDepth;
+ LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
+ LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
+ LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
auto dstMemSize = getMemoryFootprintBytes(dstForOp);
auto srcMemSize = getMemoryFootprintBytes(srcForOp);
@@ -699,8 +683,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
std::optional<double> storageReduction;
if (!dstMemSize || !srcMemSize) {
- LLVM_DEBUG(llvm::dbgs()
- << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
return false;
}
@@ -710,13 +693,13 @@ static bool isFusionProfitable(AffineForOp srcForOp,
assert(sliceMemEstimate && "expected value");
auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
- LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
- << " dst mem: " << dstMemSizeVal << "\n"
- << " fused mem: " << fusedMem << "\n"
- << " slice mem: " << sliceMemEstimate << "\n");
+ LDBG() << " src mem: " << srcMemSizeVal;
+ LDBG() << " dst mem: " << dstMemSizeVal;
+ LDBG() << " fused mem: " << fusedMem;
+ LDBG() << " slice mem: " << sliceMemEstimate;
if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ LDBG() << "Fusion is not profitable; NOT fusing.";
return false;
}
storageReduction =
@@ -734,8 +717,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
<< std::setprecision(2) << additionalComputeFraction
<< "% redundant computation and a ";
msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
- msg << "% storage reduction.\n";
- llvm::dbgs() << msg.str();
+ msg << "% storage reduction.";
+ LDBG() << msg.str();
});
return true;
@@ -895,7 +878,7 @@ public:
/// No fusion is performed when producers with a user count greater than
/// `maxSrcUserCount` for any of the memrefs involved.
void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ LDBG() << "Evaluating dst loop " << dstId;
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
return;
@@ -909,7 +892,7 @@ public:
if (dstNode->op->getNumResults() > 0)
return;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ LDBG() << "Evaluating dst loop " << dstId;
// Sink sequential loops in 'dstNode' (and thus raise parallel loops)
// while preserving relative order. This can increase the maximum loop
@@ -936,18 +919,14 @@ public:
auto *srcNode = mdg->getNode(srcId);
auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
- LLVM_DEBUG(llvm::dbgs()
- << "Trying to fuse producer loop nest " << srcId
- << " with consumer loop nest " << dstId << "\n");
- LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
- << computeToleranceThreshold << '\n');
- LLVM_DEBUG(llvm::dbgs()
- << "Producer loop nest:\n"
- << *srcNode->op << "\n and consumer loop nest:\n"
- << *dstNode->op << '\n');
+ LDBG() << "Trying to fuse producer loop nest " << srcId
+ << " with consumer loop nest " << dstId;
+ LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
+ LDBG() << "Producer loop nest:";
+ LDBG() << *srcNode->op << " and consumer loop nest:";
+ LDBG() << *dstNode->op;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
- << " for dst loop " << dstId << "\n");
+ LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
// Skip if 'srcNode' is a loop nest returning values.
// TODO: support loop nests that return values.
@@ -1018,19 +997,16 @@ public:
&depthSliceUnions[i - 1], strategy);
if (result.value == FusionResult::Success) {
maxLegalFusionDepth = i;
- LLVM_DEBUG(llvm::dbgs()
- << "Found valid slice for depth: " << i << '\n');
+ LDBG() << "Found valid slice for depth: " << i;
}
}
if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: fusion is not legal at any depth\n");
+ LDBG() << "Can't fuse: fusion is not legal at any depth";
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
- << maxLegalFusionDepth << '\n');
+ LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
double computeToleranceThresholdToUse = computeToleranceThreshold;
@@ -1040,7 +1016,7 @@ public:
// producer-consumer memref access for example). Check this and allow
// fusion accordingly.
if (hasCyclicDependence(srcAffineForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+ LDBG() << "Source nest has a cyclic dependence.";
// Maximal fusion does not check for compute tolerance threshold; so
// perform the maximal fusion only when the redundanation computation
// is zero.
@@ -1053,18 +1029,15 @@ public:
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!fraction || fraction > 0) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't perform maximal fusion with a cyclic dependence "
- "and non-zero additional compute.\n");
+ LDBG() << "Can't perform maximal fusion with a cyclic dependence "
+ << "and non-zero additional compute.";
return;
}
} else {
// Set redundant computation tolerance to zero regardless of what
// the user specified. Without this, fusion would be invalid.
- LLVM_DEBUG(llvm::dbgs()
- << "Setting compute tolerance to zero since "
- "source has a cylic dependence.\n");
+ LDBG() << "Setting compute tolerance to zero since "
+ << "source has a cylic dependence.";
computeToleranceThresholdToUse = 0;
}
}
@@ -1107,8 +1080,7 @@ public:
if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
removeSrcNode)) {
// Create a private version of this memref.
- LLVM_DEBUG(llvm::dbgs()
- << "Creating private memref for " << memref << '\n');
+ LDBG() << "Creating private memref for " << memref;
// Create a private version of this memref.
privateMemrefs.insert(memref);
}
@@ -1118,10 +1090,9 @@ public:
fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
dstNodeChanged = true;
- LLVM_DEBUG(llvm::dbgs()
- << "Fused src loop " << srcId << " into dst loop " << dstId
- << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
+ LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
+ << " at depth " << bestDstLoopDepth << ":";
+ LDBG() << dstAffineForOp;
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
if (fusedLoopInsPoint != dstAffineForOp)
@@ -1179,8 +1150,7 @@ public:
dstLoopCollector.memrefFrees);
if (removeSrcNode) {
- LLVM_DEBUG(llvm::dbgs()
- << "Removing src loop " << srcId << " after fusion\n");
+ LDBG() << "Removing src loop " << srcId << " after fusion";
// srcNode is no longer valid after it is removed from mdg.
srcAffineForOp.erase();
mdg->removeNode(srcId);
@@ -1195,7 +1165,7 @@ public:
/// user count greater than `maxSrcUserCount` for any of the memrefs involved
/// are encountered.
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
+ LDBG() << "--- Producer/Consumer Fusion ---";
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
@@ -1207,7 +1177,7 @@ public:
// Visits each node in the graph, and for each node, attempts to fuse it with
// its sibling nodes (nodes which share a parent, but no dependence edges).
void fuseSiblingNodes() {
- LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
+ LDBG() << "--- Sibling Fusion ---";
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
@@ -1289,8 +1259,7 @@ public:
maxLegalFusionDepth = i;
}
- LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
- << maxLegalFusionDepth << '\n');
+ LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
// Skip if fusion is not feasible at any loop depths.
if (maxLegalFusionDepth == 0)
@@ -1304,7 +1273,7 @@ public:
// producer-consumer memref access for example). Check this and allow
// fusion accordingly.
if (hasCyclicDependence(sibAffineForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+ LDBG() << "Source nest has a cyclic dependence.";
// Maximal fusion does not check for compute tolerance threshold; so
// perform the maximal fusion only when the redundanation computation is
// zero.
@@ -1316,17 +1285,15 @@ public:
sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!fraction || fraction > 0) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't perform maximal fusion with a cyclic dependence "
- "and non-zero additional compute.\n");
+ LDBG() << "Can't perform maximal fusion with a cyclic dependence "
+ << "and non-zero additional compute.";
return;
}
} else {
// Set redundant computation tolerance to zero regardless of what the
// user specified. Without this, fusion would be invalid.
- LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since "
- "source has a cyclic dependence.\n");
+ LDBG() << "Setting compute tolerance to zero since "
+ << "source has a cyclic dependence.";
computeToleranceThresholdToUse = 0.0;
}
}
@@ -1356,8 +1323,7 @@ public:
// slice is used in the destination.
auto isMaximal = bestSlice.isMaximal();
if (!isMaximal.value_or(false)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Slice isn't maximal; not performing sibling fusion.\n");
+ LDBG() << "Slice isn't maximal; not performing sibling fusion.";
continue;
}
@@ -1374,10 +1340,9 @@ public:
if (insertPointInst != dstForInst)
dstForInst->moveBefore(insertPointInst);
- LLVM_DEBUG(llvm::dbgs()
- << "Fused sibling nest " << sibId << " into destination nest "
- << dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
+ LDBG() << "Fused sibling nest " << sibId << " into destination nest "
+ << dstNode->id << " at depth " << bestDstLoopDepth << ":";
+ LDBG() << dstAffineForOp;
// Update data dependence graph state post fusion.
updateStateAfterSiblingFusion(sibNode, dstNode);
@@ -1555,7 +1520,7 @@ public:
void LoopFusion::runOnBlock(Block *block) {
MemRefDependenceGraph g(*block);
if (!g.init()) {
- LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
+ LDBG() << "MDG init failed";
return;
}
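The hunks above (and most of the files that follow) replace the LLVM_DEBUG(llvm::dbgs() << ...) idiom with the LDBG() stream from llvm/Support/DebugLog.h. A minimal sketch of the recurring conversion, assuming only what the hunks themselves show: the trailing '\n' is dropped (LDBG() terminates the line itself) and an optional verbosity level is accepted.

  // Hedged illustration only; DEBUG_TYPE and the surrounding function are
  // placeholders, not code from this patch.
  #include "llvm/Support/Debug.h"
  #include "llvm/Support/DebugLog.h"
  #define DEBUG_TYPE "example"

  static void debugExample(unsigned dstId) {
    // Before:
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // After: the newline is implicit.
    LDBG() << "Evaluating dst loop " << dstId;
    // With an explicit verbosity level, as used in Bufferize.cpp further down.
    LDBG(3) << "IR after bufferizing";
  }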
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 41cd739..c6abb0d 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -251,20 +252,20 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
FusionStrategy fusionStrategy) {
// Return 'failure' if 'dstLoopDepth == 0'.
if (dstLoopDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
+ LDBG() << "Cannot fuse loop nests at depth 0";
return FusionResult::FailPrecondition;
}
// Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
auto *block = srcForOp->getBlock();
if (block != dstForOp->getBlock()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
+ LDBG() << "Cannot fuse loop nests in different blocks";
return FusionResult::FailPrecondition;
}
// Return 'failure' if no valid insertion point for fused loop nest in 'block'
// exists which would preserve dependences.
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
+ LDBG() << "Fusion would violate dependences in block";
return FusionResult::FailBlockDependence;
}
@@ -277,14 +278,14 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
// Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
SmallVector<Operation *, 4> opsA;
if (!gatherLoadsAndStores(forOpA, opsA)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
+ LDBG() << "Fusing loops with affine.if unsupported";
return FusionResult::FailPrecondition;
}
// Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
SmallVector<Operation *, 4> opsB;
if (!gatherLoadsAndStores(forOpB, opsB)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
+ LDBG() << "Fusing loops with affine.if unsupported";
return FusionResult::FailPrecondition;
}
@@ -296,7 +297,7 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
// TODO: 'getMaxLoopDepth' does not support forward slice fusion.
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
+ LDBG() << "Fusion would violate loop dependences";
return FusionResult::FailFusionDependence;
}
}
@@ -339,12 +340,12 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
isSrcForOpBeforeDstForOp, srcSlice);
if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
- LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+ LDBG() << "computeSliceUnion failed";
return FusionResult::FailPrecondition;
}
if (sliceComputationResult.value ==
SliceComputationResult::IncorrectSliceFailure) {
- LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
+ LDBG() << "Incorrect slice computation";
return FusionResult::FailIncorrectSlice;
}
@@ -477,7 +478,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
auto *parentForOp = forOp->getParentOp();
if (forOp != forOpRoot) {
if (!isa<AffineForOp>(parentForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
+ LDBG() << "Expected parent AffineForOp";
return WalkResult::interrupt();
}
// Add mapping to 'forOp' from its parent AffineForOp.
@@ -498,7 +499,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount) {
// Currently only constant trip count loop nests are supported.
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
+ LDBG() << "Non-constant trip count unsupported";
return WalkResult::interrupt();
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 2de057d..cd216ef 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -21,9 +21,11 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -365,12 +367,11 @@ checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) {
if (input.size() <= 1)
return success();
if (failed(getIndexSet(ops, &cst))) {
- LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n");
+ LDBG() << "Index set computation failed!";
return failure();
}
if (!cst.isHyperRectangular(0, input.size())) {
- LLVM_DEBUG(llvm::dbgs()
- << "Non-hyperrectangular nests not supported for tiling!\n");
+ LDBG() << "Non-hyperrectangular nests not supported for tiling!";
return failure();
}
return success();
@@ -385,14 +386,13 @@ static LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input,
if (llvm::any_of(input,
[](AffineForOp op) { return op.getNumResults() > 0; })) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot tile nest where a loop has yield values\n");
+ LDBG() << "Cannot tile nest where a loop has yield values";
return failure();
}
// Check if the supplied `for` ops are all successively nested.
if (!isPerfectlyNested(input)) {
- LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested");
+ LDBG() << "input loops not perfectly nested";
return failure();
}
@@ -1098,7 +1098,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// If the trip count is lower than the unroll jam factor, no unroll jam.
if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollJamFactor) {
- LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
+ LDBG() << "[failed] trip count < unroll-jam factor";
return failure();
}
@@ -1339,6 +1339,15 @@ bool mlir::affine::isValidLoopInterchangePermutation(
unsigned maxLoopDepth = loops.size();
if (maxLoopDepth == 1)
return true;
+
+ // We cannot guarantee the validity of the interchange if the loops have
+ // iter_args, since the dependence analysis does not take them into account.
+ // Conservatively return false in such cases.
+ if (llvm::any_of(loops, [](AffineForOp loop) {
+ return loop.getNumIterOperands() > 0;
+ }))
+ return false;
+
// Gather dependence components for dependences between all ops in loop nest
// rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
@@ -1766,9 +1775,7 @@ findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
// We can't hoist past the definition of the memref being copied.
Value memref = region.memref;
if (!memref.getParentRegion()->isAncestor(enclosingOp->getParentRegion())) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "memref definition will end up not dominating hoist location\n");
+ LDBG() << "memref definition will end up not dominating hoist location";
break;
}
@@ -1977,7 +1984,7 @@ static LogicalResult generateCopy(
auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ LDBG() << "Non-identity layout map not yet supported";
return failure();
}
@@ -1989,7 +1996,7 @@ static LogicalResult generateCopy(
unsigned rank = memRefType.getRank();
if (rank == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Non-zero ranked memrefs supported\n");
+ LDBG() << "Non-zero ranked memrefs supported";
return failure();
}
@@ -2001,19 +2008,18 @@ static LogicalResult generateCopy(
std::optional<int64_t> numElements =
region.getConstantBoundingSizeAndShape(&fastBufferShape, &lbs);
if (!numElements) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
+ LDBG() << "Non-constant region size not supported";
return failure();
}
if (llvm::any_of(lbs, [](AffineMap lb) { return lb.getNumResults() > 1; })) {
// This can be supported in the future if needed.
- LLVM_DEBUG(llvm::dbgs()
- << "Max lower bound for memref region start not supported\n");
+ LDBG() << "Max lower bound for memref region start not supported";
return failure();
}
if (*numElements == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
+ LDBG() << "Nothing to copy";
return success();
}
@@ -2021,9 +2027,8 @@ static LogicalResult generateCopy(
for (unsigned i = 0; i < rank; ++i) {
region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
if (lbMaps[i].getNumResults() == 0 || ubMaps[i].getNumResults() == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Missing lower or upper bound for region along dimension: "
- << i << '\n');
+ LDBG() << "Missing lower or upper bound for region along dimension: "
+ << i;
return failure();
}
}
@@ -2122,7 +2127,7 @@ static LogicalResult generateCopy(
// TODO: use all stride levels once DmaStartOp is extended for
// multi-level strides.
if (dmaStrideInfos.size() > 1) {
- LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+ LDBG() << "Only up to one level of stride supported";
return failure();
}
@@ -2309,10 +2314,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
// surrounding the this block range.
unsigned copyDepth = getNestingDepth(&*begin);
- LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
- << "\n");
- LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
- LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
+ LDBG() << "Generating copies at depth " << copyDepth;
+ LDBG() << "from begin: "
+ << OpWithFlags(&*begin, OpPrintingFlags().skipRegions());
+ LDBG() << "to inclusive end: "
+ << OpWithFlags(&*std::prev(end), OpPrintingFlags().skipRegions());
// List of memory regions to copy for. We need a map vector to have a
// guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
@@ -2349,8 +2355,8 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
return;
if (!memref.getParentRegion()->isAncestor(block->getParent())) {
- LLVM_DEBUG(llvm::dbgs() << "memref definition is inside of the depth at "
- "which copy-in/copy-out would happen\n");
+ LDBG() << "memref definition is inside of the depth at "
+ << "which copy-in/copy-out would happen";
return;
}
@@ -2358,12 +2364,10 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
/*addMemRefDimBounds=*/false))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Error obtaining memory region: semi-affine maps?\n");
- LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
+ LDBG() << "Error obtaining memory region: semi-affine maps?";
+ LDBG() << "over-approximating to the entire memref";
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
- LLVM_DEBUG(
- opInst->emitError("non-constant memref sizes not yet supported"));
+ LDBG() << "non-constant memref sizes not yet supported";
error = true;
return;
}
@@ -2392,13 +2396,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
// Perform a union with the existing region.
if (failed(it->second->unionBoundingBox(*region))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Memory region bounding box failed; "
- "over-approximating to the entire memref\n");
+ LDBG() << "Memory region bounding box failed; "
+ << "over-approximating to the entire memref";
// If the union fails, we will overapproximate.
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
- LLVM_DEBUG(opInst->emitError(
- "non-constant memref sizes not yet supported"));
+ LDBG() << "non-constant memref sizes not yet supported";
error = true;
return true;
}
@@ -2428,8 +2430,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
});
if (error) {
- LLVM_DEBUG(begin->emitError(
- "copy generation failed for one or more memref's in this block\n"));
+ LDBG() << "copy generation failed for one or more memref's in this block";
return failure();
}
@@ -2466,8 +2467,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
processRegions(writeRegions);
if (!ret) {
- LLVM_DEBUG(begin->emitError(
- "copy generation failed for one or more memref's in this block\n"));
+ LDBG() << "copy generation failed for one or more memref's in this block";
return failure();
}
@@ -2608,7 +2608,7 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
/*boundFloorDivisor=*/nullptr,
/*ub=*/nullptr, &fullTileLbPos,
&fullTileUbPos)) {
- LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
+ LDBG() << "Can't get constant diff pair for a loop";
return nullptr;
}
@@ -2667,8 +2667,7 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
for (auto loop : inputNest) {
// TODO: straightforward to generalize to a non-unit stride.
if (loop.getStepAsInt() != 1) {
- LLVM_DEBUG(llvm::dbgs()
- << "[tile separation] non-unit stride not implemented\n");
+ LDBG() << "[tile separation] non-unit stride not implemented";
return failure();
}
SmallVector<Operation *, 1> loopOp{loop.getOperation()};
@@ -2682,8 +2681,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
/*boundFloorDivisor=*/nullptr,
/*ub=*/nullptr, &lbPos, &ubPos) ||
lbPos == ubPos) {
- LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
- "equalities not yet handled\n");
+ LDBG() << "[tile separation] Can't get constant diff / "
+ << "equalities not yet handled";
return failure();
}
@@ -2741,8 +2740,8 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
AffineIfOp ifOp = createSeparationCondition(inputNest, b);
if (!ifOp) {
fullTileLoops.front().erase();
- LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
- "separation condition\n");
+ LDBG() << "All tiles are full tiles, or failure creating "
+ << "separation condition";
return failure();
}
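The new guard in isValidLoopInterchangePermutation above conservatively rejects interchange when any loop in the nest carries iter_args, because the memref-based dependence analysis does not model loop-carried SSA values. A sketch of the check and of the kind of dependence that would otherwise be missed (the IR is a hypothetical example shown only in comments; the helper assumes the includes already present in LoopUtils.cpp):

  // A reduction carried through iter_args:
  //   %sum = affine.for %i = 0 to 10 iter_args(%acc = %cst) -> (f32) {
  //     %v = affine.load %A[%i] : memref<10xf32>
  //     %r = arith.addf %acc, %v : f32
  //     affine.yield %r : f32
  //   }
  // The %acc -> %r -> %acc chain is a loop-carried dependence that no memref
  // access exposes, so interchange validity cannot be established.
  static bool carriesIterArgs(ArrayRef<AffineForOp> loops) {
    // Mirrors the guard added above: any loop with iter operands makes the
    // analysis insufficient, so callers should bail out conservatively.
    return llvm::any_of(loops, [](AffineForOp loop) {
      return loop.getNumIterOperands() > 0;
    });
  }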
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c3..7d4d818 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
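The ArithOps.cpp hunks above route arith.xori reductions through the dedicated AtomicRMWKind::xori instead of ori: the neutral element stays zero, but the combiner produced by getReductionOp must now be an XOrIOp rather than an OrIOp. A tiny check of the identity relied upon, for illustration only:

  // Not part of the patch. Zero is neutral for both xor and or, which is why
  // the old XOrIOp -> ori mapping yielded a correct identity value but the
  // wrong combine operation.
  static_assert((0x5Au ^ 0u) == 0x5Au, "zero is the neutral element of xor");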
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 93682a9..4780dbb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -12,7 +12,7 @@ add_mlir_dialect_library(MLIRArithTransforms
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
DEPENDS
MLIRArithTransformsIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 1aa8064..35365f2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -158,13 +158,11 @@ protected:
PatternRewriter &rewriter) {
// Check iterator types for matrix multiplication.
SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
- if (!((itTypes.size() == 3 &&
- (itTypes[0] == vector::IteratorType::parallel &&
- itTypes[1] == vector::IteratorType::parallel &&
- itTypes[2] == vector::IteratorType::reduction)) ||
- (itTypes.size() == 2 &&
- (itTypes[0] == vector::IteratorType::parallel &&
- itTypes[1] == vector::IteratorType::reduction))))
+ if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction) &&
+ (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::reduction))
return rewriter.notifyMatchFailure(
op, "iterator types do not correspond to matrix multiplication");
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index 35b0bd1..6cb2a56 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -183,9 +183,9 @@ protected:
Value acc;
// Conventional names for matrix dimensions.
- int64_t M = 0;
- int64_t N = 0;
- int64_t K = 0;
+ int64_t m = 0;
+ int64_t n = 0;
+ int64_t k = 0;
// Create the matrix mulitply and accumulate operation according to
// `mmlaOp`.
@@ -286,41 +286,41 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Single-dimension vector type for the entire RHS tile.
- auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType,
+ auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType,
/*scalableDims=*/{true});
// Vector type having the same number of elements as a row in the
// accumulator/output tile and the same element type.
- auto accRowTy = VectorType::get(/*shape=*/N, resultEltType,
+ auto accRowTy = VectorType::get(/*shape=*/n, resultEltType,
/*scalableDims=*/{true});
// Vector type having twice the number of elements as a row in the
// accumulator/output tile the same element type.
- auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType,
+ auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType,
/*scalableDims=*/{true});
// Vector type having half the number of elements as a row in the
// accumulator/output tile and an integer element type with twice the bit
// width.
- auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+ auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(),
/*scalableDims=*/{true});
// Vector type having the same the number of elements as a row in the
// accumulator/output tile and an integer element type with twice the bit
// width.
- auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+ auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(),
/*scalableDims=*/{true});
Location loc = op.getLoc();
// Extract LHS sub-tiles with logical shape <2xK>.
SmallVector<Value> lhsTile;
- for (int64_t i = 0; i < M; i += 2) {
+ for (int64_t i = 0; i < m; i += 2) {
// Extract two consecutive rows of the LHS tile.
auto r0 =
vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
auto r1 =
vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
// Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
- SmallVector<int64_t> shuffleIdx(2 * K);
+ SmallVector<int64_t> shuffleIdx(2 * k);
std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
// Turn it into a scalable vector.
@@ -337,13 +337,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Extract the RHS sub-tiles with logical shape <Kx[2]>.
SmallVector<Value> rhsTile;
- for (int64_t j = 0; j < N; j += 2)
+ for (int64_t j = 0; j < n; j += 2)
rhsTile.push_back(vector::ScalableExtractOp::create(
- rewriter, loc, flatRhsType, rhs, j * K));
+ rewriter, loc, flatRhsType, rhs, j * k));
// Extract and pack the ACC sub-tiles.
SmallVector<Value> accTile;
- for (int64_t i = 0; i < M; i += 2) {
+ for (int64_t i = 0; i < m; i += 2) {
// Extract two consecutive rows of the accumulator tile.
auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
ArrayRef<int64_t>{i});
@@ -370,28 +370,28 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
}
// Extract ACC sub-tiles.
- for (int64_t j = 0; j < N; j += 2)
+ for (int64_t j = 0; j < n; j += 2)
accTile.push_back(vector::ScalableExtractOp::create(
rewriter, loc, flatAccType, accTileVec, j * 2));
}
// Emit sub-tile matrix multiplications.
SmallVector<Value> outTile;
- for (int64_t i = 0; i < M / 2; ++i)
- for (int64_t j = 0; j < N / 2; ++j) {
- Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
+ for (int64_t i = 0; i < m / 2; ++i)
+ for (int64_t j = 0; j < n / 2; ++j) {
+ Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i],
rhsTile[j]);
outTile.push_back(mmla);
}
// Unpack the OUT sub-tiles and insert into the result.
Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
- for (int64_t i = 0; i < M / 2; ++i) {
+ for (int64_t i = 0; i < m / 2; ++i) {
// Collect a number of sub-tiles in a row.
Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
- for (int64_t j = 0; j < N / 2; ++j)
+ for (int64_t j = 0; j < n / 2; ++j)
row = vector::ScalableInsertOp::create(
- rewriter, loc, outTile[i * N / 2 + j], row, j * 4);
+ rewriter, loc, outTile[i * n / 2 + j], row, j * 4);
// Unpack the row to obtain two rows of the output. If we have the out
// sub-tiles transposed we obtain two consecutive output rows by
@@ -432,9 +432,9 @@ public:
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- M = lhsType.getDimSize(0);
- N = rhsType.getDimSize(0);
- K = rhsType.getDimSize(1);
+ m = lhsType.getDimSize(0);
+ n = rhsType.getDimSize(0);
+ k = rhsType.getDimSize(1);
// Check the operands have the expected shape:
// * for LHS: fixed vector MxK
@@ -442,8 +442,8 @@ public:
// * K == 8
// * M and N even and at least 2
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
@@ -504,9 +504,9 @@ public:
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- M = lhsType.getDimSize(0);
- N = rhsType.getDimSize(0);
- K = rhsType.getDimSize(1);
+ m = lhsType.getDimSize(0);
+ n = rhsType.getDimSize(0);
+ k = rhsType.getDimSize(1);
// Check the operands have the expected shape:
// * for LHS: fixed vector MxK
@@ -514,8 +514,8 @@ public:
// * K == 4
// * M and N even and at least 2
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index ddc64ea..91e37dd 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -248,7 +248,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
Region *definingRegion = value.getParentRegion();
// Last users of the `value` inside all blocks where the value dies.
- llvm::SmallSet<Operation *, 4> lastUsers;
+ llvm::SmallPtrSet<Operation *, 4> lastUsers;
// Find blocks in the `definingRegion` that have users of the `value` (if
// there are multiple users in the block, which one will be selected is
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f1f12f4..56ff212 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
// which otherwise could prevent removal of unnecessary allocs.
Value canonicalSource = source;
while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
- canonicalSource.getDefiningOp()))
+ canonicalSource.getDefiningOp())) {
+ if (canonicalSource != iface.getViewDest()) {
+ break;
+ }
canonicalSource = iface.getViewSource();
+ }
std::optional<Operation *> maybeCloneDeallocOp =
memref::findDealloc(cloneOp.getOutput());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 8916526..a465c95 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -37,8 +37,12 @@ using namespace mlir::bufferization;
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
+ if (value != viewLikeOp.getViewDest()) {
+ break;
+ }
value = viewLikeOp.getViewSource();
+ }
return value;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 8f983ab..0b2e080 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
registerDependencies(viewInterface.getViewSource(),
- viewInterface->getResult(0));
+ viewInterface.getViewDest());
return WalkResult::advance();
}
@@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) {
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
+ if (value != viewLikeOp.getViewDest()) {
+ break;
+ }
value = viewLikeOp.getViewSource();
+ }
return value;
}
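The three getViewBase/registerDependencies hunks above add the same guard: a ViewLikeOpInterface op is only stepped through when the value being traced is that op's view destination, and dependencies are registered against getViewDest() rather than result #0. A commented restatement of the guarded walk, shown only to spell out the intent (assumes the includes already present in these files):

  static Value getViewBaseSketch(Value value) {
    while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
      // Only the op's declared view destination aliases its view source; if
      // we arrived here through some other result, stop walking.
      if (value != viewLikeOp.getViewDest())
        break;
      value = viewLikeOp.getViewSource();
    }
    return value;
  }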
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 91f6f25..68ef519 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -20,6 +20,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/DebugLog.h"
#include <optional>
namespace mlir {
@@ -328,20 +329,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
"blocks");
// Bufferize the op.
- LLVM_DEBUG(llvm::dbgs()
- << "//===-------------------------------------------===//\n"
- << "IR after bufferizing: " << nextOp->getName() << "\n");
+ LDBG(3) << "//===-------------------------------------------===//\n"
+ << "IR after bufferizing: " << nextOp->getName();
rewriter.setInsertionPoint(nextOp);
if (failed(
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
- LLVM_DEBUG(llvm::dbgs()
- << "failed to bufferize\n"
- << "//===-------------------------------------------===//\n");
+ LDBG(2) << "failed to bufferize\n"
+ << "//===-------------------------------------------===//";
return nextOp->emitError("failed to bufferize op");
}
- LLVM_DEBUG(llvm::dbgs()
- << *op
- << "\n//===-------------------------------------------===//\n");
+ LDBG(3) << *op << "\n//===-------------------------------------------===//";
}
// Return early if the top-level op is entirely gone.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index a8e8353..fb7f2bb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -56,6 +56,7 @@
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/DebugLog.h"
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
@@ -616,13 +617,10 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (getParallelRegion(def.getParentRegion(), options) !=
getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
options)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "\n- bufferizes out-of-place due to parallel region:\n");
- LLVM_DEBUG(llvm::dbgs()
- << " unConflictingWrite = operand "
- << uConflictingWrite->getOperandNumber() << " of "
- << *uConflictingWrite->getOwner() << "\n");
+ LDBG() << "\n- bufferizes out-of-place due to parallel region:\n"
+ << " unConflictingWrite = operand "
+ << uConflictingWrite->getOperandNumber() << " of "
+ << *uConflictingWrite->getOwner();
return true;
}
}
@@ -631,9 +629,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
- LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
- LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
- << " of " << *readingOp << "\n");
+ LDBG() << "\n- check conflict:\n"
+ << " uRead = operand " << uRead->getOperandNumber() << " of "
+ << *readingOp;
// Find the definition of uRead by following the SSA use-def chain.
// E.g.:
@@ -648,23 +646,22 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read value has no definitions\n");
+ LDBG() << " no conflict: read value has no definitions";
continue;
}
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
for (OpOperand *uConflictingWrite : usesWrite) {
- LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
- << uConflictingWrite->getOperandNumber() << " of "
- << *uConflictingWrite->getOwner() << "\n");
+ LDBG() << " unConflictingWrite = operand "
+ << uConflictingWrite->getOperandNumber() << " of "
+ << *uConflictingWrite->getOwner();
// Check if op dominance can be used to rule out read-after-write
// conflicts.
bool useDominance =
canUseOpDominance(uRead, uConflictingWrite, definitions, state);
- LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
+ LDBG() << "\n- useDominance = " << useDominance;
// Throughout this loop, check for multiple requirements that have to be
// met for uConflictingWrite to be an actual conflict.
@@ -680,8 +677,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// inside a loop), there may be no meaningful `happensBefore`
// relationship.
if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read happens before write\n");
+ LDBG() << " no conflict: read happens before write";
continue;
}
@@ -693,8 +689,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// Note: If the op is executed multiple times (e.g., because it is
// inside a loop), it may be conflicting with itself.
if (uConflictingWrite == uRead) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read and write are same use\n");
+ LDBG() << " no conflict: read and write are same use";
continue;
}
@@ -705,8 +700,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// multiple times.
if (state.insideMutuallyExclusiveRegions(readingOp,
conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
- "mutually exclusive regions\n");
+ LDBG() << " no conflict: read and write are in "
+ "mutually exclusive regions";
continue;
}
@@ -721,9 +716,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
state, uRead, uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
state, uConflictingWrite, uRead->get())) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: op bufferizes to element-wise access\n");
+ LDBG() << " no conflict: op bufferizes to element-wise access";
continue;
}
}
@@ -733,15 +726,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// No conflict if the operands are non-conflicting subsets.
if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
+ LDBG() << " no conflict: non-conflicting subsets";
continue;
}
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: op interace of reading op says 'no'\n");
+ LDBG() << " no conflict: op interace of reading op says 'no'";
continue;
}
}
@@ -751,9 +743,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
options.dynCastBufferizableOp(conflictingWritingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
state)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: op interace of writing op says 'no'\n");
+ LDBG() << " no conflict: op interace of writing op says 'no'";
continue;
}
}
@@ -761,29 +751,26 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// Check all possible definitions.
for (Value definition : definitions) {
- LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
+ LDBG() << " * definition = " << definition;
// No conflict if the conflicting write happens before the definition.
if (Operation *defOp = definition.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
// conflictingWritingOp happens before defOp. No conflict.
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: write happens before definition\n");
+ LDBG() << " no conflict: write happens before definition";
continue;
}
// No conflict if conflictingWritingOp is contained in defOp.
if (defOp->isProperAncestor(conflictingWritingOp)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: write is contained in definition\n");
+ LDBG() << " no conflict: write is contained in definition";
continue;
}
} else {
auto bbArg = cast<BlockArgument>(definition);
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
- "and write happens outside of block\n");
+ LDBG() << " no conflict: definition is bbArg "
+ "and write happens outside of block";
// conflictingWritingOp happens outside of the block. No
// conflict.
continue;
@@ -795,8 +782,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
if (aliases.getNumAliases() == 1 &&
aliases.getAliases()[0].value == definition) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: definition and write are same\n");
+ LDBG() << " no conflict: definition and write are same";
continue;
}
@@ -804,7 +790,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (options.printConflicts)
annotateConflict(uRead, uConflictingWrite, definition);
- LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
+ LDBG() << " => RaW CONFLICT FOUND";
return true;
}
}
@@ -958,7 +944,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
for (AliasingValue alias : state.getAliasingValues(operand))
state.applyOnAliases(alias.value, checkReadOnly);
if (foundReadOnly) {
- LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
+ LDBG() << "=> NOT WRITABLE";
return true;
}
@@ -987,10 +973,9 @@ void OneShotAnalysisState::resetCache() {
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
const DominanceInfo &domInfo) {
- LLVM_DEBUG(
- llvm::dbgs() << "//===-------------------------------------------===//\n"
- << "Analyzing operand #" << operand.getOperandNumber()
- << " of " << *operand.getOwner() << "\n");
+ LDBG() << "//===-------------------------------------------===//\n"
+ << "Analyzing operand #" << operand.getOperandNumber() << " of "
+ << *operand.getOwner();
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, state) ||
@@ -1001,8 +986,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
else
state.bufferizeInPlace(operand);
- LLVM_DEBUG(llvm::dbgs()
- << "//===-------------------------------------------===//\n");
+ LDBG() << "//===-------------------------------------------===//";
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 725fa24..b593cca 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -51,14 +51,8 @@ static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
/// Return "true" if the given op is guaranteed to have neither "Allocate" nor
/// "Free" side effects.
static bool hasNeitherAllocateNorFreeSideEffect(Operation *op) {
- if (isa<MemoryEffectOpInterface>(op))
- return !hasEffect<MemoryEffects::Allocate>(op) &&
- !hasEffect<MemoryEffects::Free>(op);
- // If the op does not implement the MemoryEffectOpInterface but has has
- // recursive memory effects, then this op in isolation (without its body) does
- // not have any side effects. All the ops inside the regions of this op will
- // be processed separately.
- return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+ return !mightHaveEffect<MemoryEffects::Allocate>(op) &&
+ !mightHaveEffect<MemoryEffects::Free>(op);
}
/// Return "true" if the given op has buffer semantics. I.e., it has buffer
@@ -517,9 +511,7 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
// MemoryEffectOpInterface. They usually do not have side effects apart
// from the callee, which will be analyzed separately. (This is similar to
// "recursive memory effects".)
- if (!isa<MemoryEffectOpInterface>(op) &&
- !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() &&
- !isa<CallOpInterface>(op))
+ if (hasUnknownEffects(op) && !isa<CallOpInterface>(op))
return op->emitError(
"ops with unknown memory side effects are not supported");
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 37b4cfc..47740d3 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms
BufferizableOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
LINK_LIBS PUBLIC
MLIRBufferizationDialect
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index 34f2dd5..3f6dd29 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -24,7 +24,7 @@ LogicalResult mlir::impl::verifyHasDefaultDLTIDataLayoutTrait(Operation *op) {
}
DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
- return op->getAttrOfType<DataLayoutSpecAttr>(
+ return op->getAttrOfType<DataLayoutSpecInterface>(
DLTIDialect::kDataLayoutAttrName);
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e79da92..00ce3b5 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
type);
}
+bool mlir::emitc::isFundamentalType(Type type) {
+ return llvm::isa<IndexType>(type) || isPointerWideType(type) ||
+ isSupportedIntegerType(type) || isSupportedFloatType(type) ||
+ isa<emitc::PointerType>(type);
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
@@ -375,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
// ExpressionOp
//===----------------------------------------------------------------------===//
+ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ if (parser.parseOperandList(operands))
+ return parser.emitError(parser.getCurrentLocation()) << "expected operands";
+ if (succeeded(parser.parseOptionalKeyword("noinline")))
+ result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name),
+ parser.getBuilder().getUnitAttr());
+ Type type;
+ if (parser.parseColonType(type))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected function type");
+ auto fnType = llvm::dyn_cast<FunctionType>(type);
+ if (!fnType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected function type");
+ if (parser.resolveOperands(operands, fnType.getInputs(),
+ parser.getCurrentLocation(), result.operands))
+ return failure();
+ if (fnType.getNumResults() != 1)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected single return type");
+ result.addTypes(fnType.getResults());
+ Region *body = result.addRegion();
+ SmallVector<OpAsmParser::Argument> argsInfo;
+ for (auto [unresolvedOperand, operandType] :
+ llvm::zip(operands, fnType.getInputs())) {
+ OpAsmParser::Argument argInfo;
+ argInfo.ssaName = unresolvedOperand;
+ argInfo.type = operandType;
+ argsInfo.push_back(argInfo);
+ }
+ if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+ return failure();
+ return success();
+}
+
+void emitc::ExpressionOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p.printOperands(getDefs());
+ p << " : ";
+ p.printFunctionalType(getOperation());
+ p.shadowRegionArgs(getRegion(), getDefs());
+ p << ' ';
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
Operation *ExpressionOp::getRootOp() {
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
Value yieldedValue = yieldOp.getResult();
@@ -1395,6 +1447,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
+
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
TypeAttr type,
Attribute initialValue) {
@@ -1452,6 +1505,15 @@ LogicalResult FieldOp::verify() {
//===----------------------------------------------------------------------===//
// GetFieldOp
//===----------------------------------------------------------------------===//
+
+LogicalResult GetFieldOp::verify() {
+ auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
+ if (!parentClassOp.getOperation())
+ return emitOpError("must be nested within an emitc.class operation");
+
+ return success();
+}
+
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
FieldOp fieldOp =
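The new ExpressionOp::parse/print above give emitc.expression an explicit operand list whose values are re-bound as region block arguments (with name shadowing), an optional noinline keyword between the operand list and the colon, and a single-result function type. A hypothetical textual form the parser would accept, shown only in comments; the body ops are illustrative examples, not mandated syntax:

  // %r = emitc.expression %a, %b : (i32, i32) -> i32 {
  //   %c = emitc.add %a, %b : (i32, i32) -> i32
  //   emitc.yield %c : i32
  // }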
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 3f0690c..f8469b8 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -9,7 +9,9 @@
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace emitc {
@@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
Location loc = op->getLoc();
builder.setInsertionPointAfter(op);
- auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType);
+ auto expressionOp =
+ emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
// Replace all op's uses with the new expression's result.
result.replaceAllUsesWith(expressionOp.getResult());
- // Create an op to yield op's value.
- Region &region = expressionOp.getRegion();
- Block &block = region.emplaceBlock();
+ Block &block = expressionOp.createBody();
+ IRMapping mapper;
+ for (auto [operand, arg] :
+ llvm::zip(expressionOp.getOperands(), block.getArguments()))
+ mapper.map(operand, arg);
builder.setInsertionPointToEnd(&block);
- auto yieldOp = emitc::YieldOp::create(builder, loc, result);
- // Move op into the new expression.
- op->moveBefore(yieldOp);
+ Operation *rootOp = builder.clone(*op, mapper);
+ op->erase();
+ // Create an op to yield op's value.
+ emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
return expressionOp;
}
@@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
using OpRewritePattern<ExpressionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpressionOp expressionOp,
PatternRewriter &rewriter) const override {
- bool anythingFolded = false;
- for (Operation &op : llvm::make_early_inc_range(
- expressionOp.getBody()->without_terminator())) {
- // Don't fold expressions whose result value has its address taken.
- auto applyOp = dyn_cast<emitc::ApplyOp>(op);
- if (applyOp && applyOp.getApplicableOperator() == "&")
- continue;
-
- for (Value operand : op.getOperands()) {
- auto usedExpression = operand.getDefiningOp<ExpressionOp>();
- if (!usedExpression)
- continue;
-
- // Don't fold expressions with multiple users: assume any
- // re-materialization was done separately.
- if (!usedExpression.getResult().hasOneUse())
- continue;
-
- // Don't fold expressions with side effects.
- if (usedExpression.hasSideEffects())
- continue;
-
- // Fold the used expression into this expression by cloning all
- // instructions in the used expression just before the operation using
- // its value.
- rewriter.setInsertionPoint(&op);
- IRMapping mapper;
- for (Operation &opToClone :
- usedExpression.getBody()->without_terminator()) {
- Operation *clone = rewriter.clone(opToClone, mapper);
- mapper.map(&opToClone, clone);
- }
-
- Operation *expressionRoot = usedExpression.getRootOp();
- Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
- assert(clonedExpressionRootOp &&
- "Expected cloned expression root to be in mapper");
- assert(clonedExpressionRootOp->getNumResults() == 1 &&
- "Expected cloned root to have a single result");
-
- rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
- anythingFolded = true;
- }
+ Block *expressionBody = expressionOp.getBody();
+ ExpressionOp usedExpression;
+ SetVector<Value> foldedOperands;
+
+ auto takesItsOperandsAddress = [](Operation *user) {
+ auto applyOp = dyn_cast<emitc::ApplyOp>(user);
+ return applyOp && applyOp.getApplicableOperator() == "&";
+ };
+
+    // Select, as the expression to fold, the first operand expression that
+ // - doesn't have its result value's address taken,
+ // - has a single user: assume any re-materialization was done separately,
+ // - has no side effects,
+ // and save all other operands to be used later as operands in the folded
+ // expression.
+ for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
+ expressionBody->getArguments())) {
+ ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
+ if (usedExpression || !operandExpression ||
+ llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
+ !operandExpression.getResult().hasOneUse() ||
+ operandExpression.hasSideEffects())
+ foldedOperands.insert(operand);
+ else
+ usedExpression = operandExpression;
}
- return anythingFolded ? success() : failure();
+
+ // If no operand expression was selected, bail out.
+ if (!usedExpression)
+ return failure();
+
+ // Collect additional operands from the folded expression.
+ for (Value operand : usedExpression.getOperands())
+ foldedOperands.insert(operand);
+
+ // Create a new expression to hold the folding result.
+ rewriter.setInsertionPointAfter(expressionOp);
+ auto foldedExpression = emitc::ExpressionOp::create(
+ rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+ foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
+ Block &foldedExpressionBody = foldedExpression.createBody();
+
+ // Map each operand of the new expression to its matching block argument.
+ IRMapping mapper;
+ for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
+ foldedExpressionBody.getArguments()))
+ mapper.map(operand, arg);
+
+ // Prepare to fold the used expression and the matched expression into the
+ // newly created folded expression.
+ auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
+ bool withTerminator) {
+ Block *expressionToFoldBody = expressionToFold.getBody();
+ for (auto [operand, arg] :
+ llvm::zip(expressionToFold.getOperands(),
+ expressionToFoldBody->getArguments())) {
+ mapper.map(arg, mapper.lookup(operand));
+ }
+
+ for (Operation &opToClone : expressionToFoldBody->without_terminator())
+ rewriter.clone(opToClone, mapper);
+
+ if (withTerminator)
+ rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
+ };
+ rewriter.setInsertionPointToStart(&foldedExpressionBody);
+
+ // First, fold the used expression into the new expression and map its
+ // result to the clone of its root operation within the new expression.
+ foldExpression(usedExpression, /*withTerminator=*/false);
+ Operation *expressionRoot = usedExpression.getRootOp();
+ Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
+ assert(clonedExpressionRootOp &&
+ "Expected cloned expression root to be in mapper");
+ assert(clonedExpressionRootOp->getNumResults() == 1 &&
+ "Expected cloned root to have a single result");
+ mapper.map(usedExpression.getResult(),
+ clonedExpressionRootOp->getResults()[0]);
+
+ // Now fold the matched expression into the new expression.
+ foldExpression(expressionOp, /*withTerminator=*/true);
+
+ // Complete the rewrite.
+ rewriter.replaceOp(expressionOp, foldedExpression);
+ rewriter.eraseOp(usedExpression);
+
+ return success();
}
};
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index c55e26e..06d7e07 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -64,8 +64,8 @@ public:
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
- FieldOp fieldop = rewriter.create<emitc::FieldOp>(
- funcOp->getLoc(), fieldName, typeAttr, nullptr);
+ FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+ fieldName, typeAttr, nullptr);
if (argAttrs && idx < argAttrs->size()) {
fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 2503ccb..b87b4f4 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2486,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
if (getArgs().size() != getWarpRegion().getNumArguments())
return emitOpError(
"expected same number op arguments and block arguments.");
- auto yield =
- cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = getTerminator();
if (yield.getNumOperands() != getNumResults())
return emitOpError(
"expected same number of yield operands and return values.");
@@ -2511,6 +2510,50 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
+gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
+ return cast<gpu::YieldOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupBroadcastOp
+//===----------------------------------------------------------------------===//
+
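+/// The broadcast result carries the same integer range as the value being
+/// broadcast.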
+void gpu::SubgroupBroadcastOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), argRanges.front());
+}
+
+Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
+ switch (getBroadcastType()) {
+ case BroadcastType::first_active_lane:
+    // Cannot speculate a first_active_lane broadcast, because speculating it
+    // across control flow can change the active lanes.
+ return Speculation::NotSpeculatable;
+ case BroadcastType::any_lane:
+ LLVM_FALLTHROUGH;
+ case BroadcastType::specific_lane:
+    // Speculation should be safe as long as we are inside structured control
+    // flow.
+ return Speculation::Speculatable;
+ }
+}
+
+LogicalResult gpu::SubgroupBroadcastOp::verify() {
+ switch (getBroadcastType()) {
+ case BroadcastType::first_active_lane:
+ LLVM_FALLTHROUGH;
+ case BroadcastType::any_lane:
+ if (getLane())
+ return emitOpError()
+ << "lane can only be specified for `specific_lane` broadcast";
+ return success();
+ case BroadcastType::specific_lane:
+ if (!getLane())
+ return emitOpError()
+ << "lane must be specified for `specific_lane` broadcast";
+ return success();
+ }
+}
+
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 21cb2f6..c766539 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
@@ -43,6 +44,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
+#include <optional>
#include <type_traits>
using namespace mlir;
@@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ std::optional<StringRef> chipsetName = getChipset();
+ std::optional<amdgpu::Chipset> maybeChipset;
+ if (chipsetName) {
+ FailureOr<amdgpu::Chipset> parsedChipset =
+ amdgpu::Chipset::parse(*chipsetName);
+ assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
+ maybeChipset = parsedChipset;
+ }
+
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 9bf11c7..d2c2138 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -25,6 +25,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_GPUELIMINATEBARRIERS
@@ -37,9 +38,6 @@ using namespace mlir::gpu;
#define DEBUG_TYPE "gpu-erase-barriers"
#define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
-
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.
@@ -424,27 +422,18 @@ static bool maybeCaptured(Value v) {
/// everything. This seems sufficient to achieve barrier removal in structured
/// control flow, more complex cases would require a proper dataflow analysis.
static bool mayAlias(Value first, Value second) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
- DBGS_ALIAS() << "checking aliasing between ";
- DBGS_ALIAS() << first << "\n";
- DBGS_ALIAS() << " and ";
- DBGS_ALIAS() << second << "\n";
- });
+ LDBG(DEBUG_TYPE_ALIAS, 1)
+ << "checking aliasing between " << first << " and " << second;
first = getBase(first);
second = getBase(second);
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
- DBGS_ALIAS() << "base ";
- DBGS_ALIAS() << first << "\n";
- DBGS_ALIAS() << " and ";
- DBGS_ALIAS() << second << "\n";
- });
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "base " << first << " and " << second;
// Values derived from the same base memref do alias (unless we do a more
// advanced analysis to prove non-overlapping accesses).
if (first == second) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n");
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "-> do alias!";
return true;
}
@@ -493,7 +482,7 @@ static bool mayAlias(Value first, Value second) {
return false;
// Otherwise, conservatively assume aliasing.
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n");
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "-> may alias!";
return true;
}
@@ -567,20 +556,16 @@ haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
continue;
// Other kinds of effects create a conflict, e.g. read-after-write.
- LLVM_DEBUG(
- DBGS() << "found a conflict between (before): " << before.getValue()
- << " read:" << isa<MemoryEffects::Read>(before.getEffect())
- << " write:" << isa<MemoryEffects::Write>(before.getEffect())
- << " alloc:"
- << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:"
- << isa<MemoryEffects::Free>(before.getEffect()) << "\n");
- LLVM_DEBUG(
- DBGS() << "and (after): " << after.getValue()
- << " read:" << isa<MemoryEffects::Read>(after.getEffect())
- << " write:" << isa<MemoryEffects::Write>(after.getEffect())
- << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
- << " free:" << isa<MemoryEffects::Free>(after.getEffect())
- << "\n");
+ LDBG() << "found a conflict between (before): " << before.getValue()
+ << " read:" << isa<MemoryEffects::Read>(before.getEffect())
+ << " write:" << isa<MemoryEffects::Write>(before.getEffect())
+ << " alloc:" << isa<MemoryEffects::Allocate>(before.getEffect())
+ << " free:" << isa<MemoryEffects::Free>(before.getEffect());
+ LDBG() << "and (after): " << after.getValue()
+ << " read:" << isa<MemoryEffects::Read>(after.getEffect())
+ << " write:" << isa<MemoryEffects::Write>(after.getEffect())
+ << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
+ << " free:" << isa<MemoryEffects::Free>(after.getEffect());
return true;
}
}
@@ -595,8 +580,8 @@ public:
LogicalResult matchAndRewrite(BarrierOp barrier,
PatternRewriter &rewriter) const override {
- LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "
- << barrier.getLoc() << "\n");
+ LDBG() << "checking the necessity of: " << barrier << " "
+ << barrier.getLoc();
SmallVector<MemoryEffects::EffectInstance> beforeEffects;
getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true);
@@ -605,14 +590,12 @@ public:
getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true);
if (!haveConflictingEffects(beforeEffects, afterEffects)) {
- LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "
- << barrier << "\n");
+ LDBG() << "the surrounding barriers are sufficient, removing " << barrier;
rewriter.eraseOp(barrier);
return success();
}
- LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "
- << barrier.getLoc() << "\n");
+ LDBG() << "barrier is necessary: " << barrier << " " << barrier.getLoc();
return failure();
}
};
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d4978ca..97adad6 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -431,8 +431,7 @@ private:
if (std::optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName =
- cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
+ StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 18c69f5..67cef8a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -11,16 +11,21 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
+#include <optional>
using namespace mlir;
namespace {
+
+constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
+
/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
/// and offset must be a constant integer in the range [0, 31].
struct PromoteShuffleToSwizzlePattern
@@ -56,9 +61,48 @@ struct PromoteShuffleToSwizzlePattern
return success();
}
};
+
+/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToPermlanePattern
+ : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getMode() != gpu::ShuffleMode::XOR)
+ return rewriter.notifyMatchFailure(op,
+ "only xor shuffle mode is supported");
+
+ if (!isConstantIntValue(op.getWidth(), 64))
+ return rewriter.notifyMatchFailure(op,
+ "only 64 width shuffle is supported");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ if (offsetValue != 16 && offsetValue != 32)
+      return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
+
+ Location loc = op.getLoc();
+ Value res = amdgpu::PermlaneSwapOp::create(
+ rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
+    Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width=*/1);
+ rewriter.replaceOp(op, {res, valid});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
- RewritePatternSet &patterns) {
- patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
+ RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
+ patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
+                                               /*benefit=*/1);
+ if (maybeChipset && *maybeChipset >= kGfx950)
+ patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
+                                                  /*benefit=*/2);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
index e9cf493..6da76e9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "llvm/Support/Regex.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 384d1a0..88f531f 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
#include <numeric>
@@ -55,28 +56,30 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
SmallVector<size_t> &indices) const {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
+ gpu::YieldOp yield = warpOp.getTerminator();
+ SmallVector<Value> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ llvm::SmallDenseMap<Value, unsigned> indexLookup;
+ // Record the value -> first index mapping for faster lookup.
+ for (auto [i, v] : llvm::enumerate(yieldValues)) {
+ if (!indexLookup.count(v))
+ indexLookup[v] = i;
+ }
+
for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(value)) {
+ // If the value already exists in the yield, don't create a new output.
+ if (indexLookup.count(value)) {
+ indices.push_back(indexLookup[value]);
+ } else {
+ // If the value is new, add it to the yield and to the types.
+ yieldValues.push_back(value);
types.push_back(type);
indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == value) {
- indices.push_back(idx);
- break;
- }
- }
}
}
- yieldValues.insert_range(newYieldedValues);
+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
+ rewriter, warpOp, yieldValues, types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
@@ -85,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
OpOperand *WarpDistributionPattern::getWarpResult(
WarpExecuteOnLane0Op warpOp,
llvm::function_ref<bool(Operation *)> fn) const {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
for (OpOperand &yieldOperand : yield->getOpOperands()) {
Value yieldValues = yieldOperand.get();
Operation *definedOp = yieldValues.getDefiningOp();
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ff55f17..ec581ac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
+ MLIRPtrMemorySpaceInterfaces
MLIRSideEffectInterfaces
MLIRSupport
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 894de44..7220e10 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,20 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
@@ -28,50 +38,122 @@ using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
-static char getRegisterType(Type type) {
- if (type.isInteger(1))
- return 'b';
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == kSharedMemorySpace) {
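+/// Maps an MLIR type to its PTX inline-asm register constraint letter,
+/// widening small packed vectors to an equivalent scalar register type.
+/// Emits a diagnostic at `loc` and fails for unsupported types.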
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+ MLIRContext *ctx = type.getContext();
+ auto i16 = IntegerType::get(ctx, 16);
+ auto i32 = IntegerType::get(ctx, 32);
+ auto f32 = Float32Type::get(ctx);
+
+ auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
+ if (type.isInteger(1))
+ return 'b';
+ if (type.isInteger(16))
+ return 'h';
+ if (type.isInteger(32))
return 'r';
+ if (type.isInteger(64))
+ return 'l';
+ if (type.isF32())
+ return 'f';
+ if (type.isF64())
+ return 'd';
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
+      // The shared address space is addressed with 32-bit pointers.
+ if (ptr.getAddressSpace() == kSharedMemorySpace) {
+ return 'r';
+ }
+ return 'l';
}
- return 'l';
+    // Register types for structs are not supported.
+    mlir::emitError(
+        loc, "The register type could not be deduced from the MLIR type. ")
+        << type
+        << " is not supported. Supported types are: "
+           "i1, i16, i32, i64, f32, f64, "
+           "and pointers.\nPlease use llvm.bitcast if you have a different "
+           "type.\nSee the constraints here: "
+ "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+ "index.html#constraints";
+ return failure();
+ };
+
+ // Packed registers
+ if (auto v = dyn_cast<VectorType>(type)) {
+ assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
+
+ int64_t lanes = v.getNumElements();
+ Type elem = v.getElementType();
+
+ // Case 1. Single vector
+ if (lanes <= 1)
+ return getRegisterTypeForScalar(elem);
+
+ // Case 2. Packed registers
+ Type widened = elem;
+ switch (lanes) {
+
+ case 2:
+      if (elem.isF16() || elem.isBF16()) // vector<2xf16> or vector<2xbf16>
+ widened = f32;
+ else if (elem.isFloat(8)) // vector<2xf8>
+ widened = i16;
+ break;
+ case 4:
+      if (elem.isInteger(8)) // vector<4xi8>
+        widened = i32;
+      else if (elem.isFloat(8)) // vector<4xf8>
+        widened = f32;
+      else if (elem.isFloat(4)) // vector<4xf4>
+ widened = i16;
+ break;
+ // Other packing is not supported
+ default:
+ break;
+ }
+ return getRegisterTypeForScalar(widened);
}
- // register type for struct is not supported.
- llvm_unreachable("The register type could not deduced from MLIR type");
- return '?';
+
+ return getRegisterTypeForScalar(type);
}
-static char getRegisterType(Value v) {
+static FailureOr<char> getRegisterType(Value v, Location loc) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
- return getRegisterType(v.getType());
+ return getRegisterType(v.getType(), loc);
}
-void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
- LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+/// Extract every element of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+ Location loc, Value structVal) {
+ auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
+ assert(structTy && "expected LLVM struct");
+
+ SmallVector<Value> elems;
+ for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
+    elems.push_back(
+        LLVM::ExtractValueOp::create(rewriter, loc, structVal, i));
+
+ return elems;
+}
+
+LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+  LDBG() << v << "\t Modifier : " << itype;
+ registerModifiers.push_back(itype);
+
+ Location loc = interfaceOp->getLoc();
auto getModifier = [&]() -> const char * {
- if (itype == PTXRegisterMod::ReadWrite) {
- assert(false && "Read-Write modifier is not supported. Try setting the "
- "same value as Write and Read separately.");
- return "+";
- }
- if (itype == PTXRegisterMod::Write) {
+ switch (itype) {
+ case PTXRegisterMod::Read:
+ return "";
+ case PTXRegisterMod::Write:
return "=";
+ case PTXRegisterMod::ReadWrite:
+    // The read-write modifier is not actually supported; the interface will
+    // rewrite it to "=" later and add the integer mapping.
+ return "+";
}
- return "";
+ llvm_unreachable("Unknown PTX register modifier");
};
+
auto addValue = [&](Value v) {
if (itype == PTXRegisterMod::Read) {
ptxOperands.push_back(v);
@@ -90,35 +172,273 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
- Value extractValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), v, idx);
+ Value extractValue =
+ LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
- ss << getModifier() << getRegisterType(t) << ",";
+ FailureOr<char> regType = getRegisterType(t, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc,
+ "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
}
- return;
+ return success();
}
// Handle Scalars
addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
+ FailureOr<char> regType = getRegisterType(v, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
+ return success();
+}
+
+/// Check if the operation needs to pack and unpack results.
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+ if (needsManualRegisterMapping)
+ return false;
+ const unsigned writeOnlyVals = interfaceOp->getNumResults();
+ const unsigned readWriteVals =
+ llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+ return (writeOnlyVals + readWriteVals) > 1;
+}
+
+/// Pack the written values (read-write operands and declared results) of the
+/// interface operation. If more than one value is written, their types are
+/// packed into a literal struct type; otherwise the single original type is
+/// returned.
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+ SmallVectorImpl<Value> &ptxOperands) {
+ MLIRContext *ctx = interfaceOp->getContext();
+ TypeRange resultRange = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ // Single value path:
+ if (interfaceOp->getResults().size() == 1)
+ return SmallVector<Type>{resultRange.front()};
+
+ // No declared results: if there is an RW, forward its type.
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ return SmallVector<Type>{v.getType()};
+ }
+
+ SmallVector<Type> packed;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ packed.push_back(v.getType());
+ for (Type t : resultRange)
+ packed.push_back(t);
+
+ if (packed.empty())
+ return {};
+
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+ return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+/// - Turn every "+X" into "=X"
+/// - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+ SmallVector<llvm::StringRef> toks;
+ SmallVector<std::string> out;
+ SmallVector<unsigned> plusIdx;
+
+ csv.split(toks, ',');
+ out.reserve(toks.size() + 8);
+
+ for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+ StringRef t = toks[i].trim();
+ if (t.consume_front("+")) {
+ plusIdx.push_back(i);
+ out.push_back(("=" + t).str());
+ } else {
+ out.push_back(t.str());
+ }
+ }
+
+ // Append indices of original "+X" tokens.
+ for (unsigned idx : plusIdx)
+ out.push_back(std::to_string(idx));
+
+ // Join back to CSV.
+ std::string result;
+ result.reserve(csv.size() + plusIdx.size() * 2);
+ llvm::raw_string_ostream os(result);
+ for (size_t i = 0; i < out.size(); ++i) {
+ if (i)
+ os << ',';
+ os << out[i];
+ }
+ return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
+constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
+constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
+
+/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
+static llvm::Regex getPredicateMappingRegex() {
+ llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
+ kReadWritePrefix, kWriteOnlyPrefix,
+ kReadOnlyPrefix)
+ .str());
+ return rx;
+}
+
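+/// Collects the distinct numbers used in {$rwN}, {$wN} and {$rN} placeholders
+/// of `ptxCode`, recording each number once per kind in order of first
+/// appearance.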
+void mlir::NVVM::countPlaceholderNumbers(
+ StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
+ llvm::SmallDenseSet<unsigned int> &seenW,
+ llvm::SmallDenseSet<unsigned int> &seenR,
+ llvm::SmallVectorImpl<unsigned int> &rwNums,
+ llvm::SmallVectorImpl<unsigned int> &wNums,
+ llvm::SmallVectorImpl<unsigned int> &rNums) {
+
+ llvm::Regex rx = getPredicateMappingRegex();
+ StringRef rest = ptxCode;
+
+ SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+ while (!rest.empty() && rx.match(rest, &m)) {
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+ // Insert it into the vector only the first time we see this number
+ if (m[1].equals_insensitive(kReadWritePrefix)) {
+ if (seenRW.insert(num).second)
+ rwNums.push_back(num);
+ } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
+ if (seenW.insert(num).second)
+ wNums.push_back(num);
+ } else {
+ if (seenR.insert(num).second)
+ rNums.push_back(num);
+ }
+
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+}
+
+/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
+/// compact `$K` indices:
+/// - All `rw*` first (sorted by N),
+/// - Then `w*`,
+/// - Then `r*`.
+/// If there is a predicate, it always comes at the end.
+/// Each number is assigned once; duplicates are ignored.
+///
+/// Example Input:
+/// "{
+/// reg .pred p;
+///     setp.ge.s32 p, {$r0}, {$r1};
+/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
+/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
+/// selp.s32 {$w0}, {$r0}, {$r1}, p;
+/// selp.s32 {$w1}, {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+/// reg .pred p;
+///     setp.ge.s32 p, $4, $5;
+/// selp.s32 $0, $4, $5, p;
+/// selp.s32 $1, $4, $5, p;
+/// selp.s32 $2, $4, $5, p;
+/// selp.s32 $3, $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
+ llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+ llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+ // Step 1. Count Register Placeholder numbers
+ countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
+
+ // Step 2. Sort the Register Placeholder numbers
+ llvm::sort(rwNums);
+ llvm::sort(wNums);
+ llvm::sort(rNums);
+
+ // Step 3. Create mapping from original to new IDs
+ llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+ unsigned nextId = 0;
+ for (unsigned n : rwNums)
+ rwMap[n] = nextId++;
+ for (unsigned n : wNums)
+ wMap[n] = nextId++;
+ for (unsigned n : rNums)
+ rMap[n] = nextId++;
+
+ // Step 4. Rewrite the PTX code with new IDs
+ std::string out;
+ out.reserve(ptxCode.size());
+ size_t prev = 0;
+ StringRef rest = ptxCode;
+ SmallVector<StringRef, 3> matches;
+ llvm::Regex rx = getPredicateMappingRegex();
+ while (!rest.empty() && rx.match(rest, &matches)) {
+ // Compute absolute match bounds in the original buffer.
+ size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
+ size_t absEnd = absStart + matches[0].size();
+
+ // Emit text before the match.
+ out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
+
+ // Emit compact $K
+ unsigned num = 0;
+ (void)matches[2].getAsInteger(10, num);
+ unsigned id = 0;
+ if (matches[1].equals_insensitive(kReadWritePrefix))
+ id = rwMap.lookup(num);
+ else if (matches[1].equals_insensitive(kWriteOnlyPrefix))
+ id = wMap.lookup(num);
+ else
+ id = rMap.lookup(num);
+
+ out.push_back('$');
+ out += std::to_string(id);
+
+ prev = absEnd;
+
+ const size_t advance =
+ (size_t)(matches[0].data() - rest.data()) + matches[0].size();
+ rest = rest.drop_front(advance);
+ }
+
+ // Step 5. Tail.
+ out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
+ return out;
}
LLVM::InlineAsmOp PtxBuilder::build() {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- auto resultTypes = interfaceOp->getResultTypes();
+ SmallVector<Type> resultTypes = packResultTypes(
+ interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
registerConstraints[registerConstraints.size() - 1] == ',')
registerConstraints.pop_back();
+ registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
std::string ptxInstruction = interfaceOp.getPtx();
+ if (!needsManualRegisterMapping)
+ ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
@@ -136,7 +456,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
rewriter, interfaceOp->getLoc(),
/*result types=*/resultTypes,
/*operands=*/ptxOperands,
- /*asm_string=*/llvm::StringRef(ptxInstruction),
+ /*asm_string=*/ptxInstruction,
/*constraints=*/registerConstraints.data(),
/*has_side_effects=*/interfaceOp.hasSideEffect(),
/*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -146,10 +466,89 @@ LLVM::InlineAsmOp PtxBuilder::build() {
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
- LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
- rewriter.replaceOp(interfaceOp, inlineAsmOp);
- } else {
+ LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
+
+ // Case 0: no result at all → just erase wrapper op.
+ if (!hasResult) {
rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ if (needsManualRegisterMapping) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ return;
+ }
+
+ // Case 1: Simple path, return single scalar
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ if (inlineAsmOp->getNumResults() > 0) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ } else {
+ // RW-only case with no declared results: forward the RW value.
+ SmallVector<Value> results;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite) {
+ results.push_back(v);
+ break;
+ }
+ rewriter.replaceOp(interfaceOp, results);
+ }
+ return;
+ }
+
+ const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+
+ // All multi-value paths produce a single struct result we need to unpack.
+ assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+ "expected struct return for multi-result inline asm");
+ Value structVal = inlineAsmOp.getResult(0);
+ SmallVector<Value> unpacked =
+ extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+ // Case 2: only declared results (no RW): replace the op with all unpacked.
+ if (!hasRW && interfaceOp->getResults().size() > 0) {
+ rewriter.replaceOp(interfaceOp, unpacked);
+ return;
+ }
+
+ // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+ if (hasRW && interfaceOp->getResults().size() == 0) {
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ // Case 4: mixed (RW + declared results).
+ {
+ // First rewrite RW operands in place.
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ // The remaining unpacked values correspond to the declared results.
+ SmallVector<Value> tail;
+ tail.reserve(unpacked.size() - idx);
+ for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+ tail.push_back(unpacked[i]);
+
+ rewriter.replaceOp(interfaceOp, tail);
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 1e02bfe..e268e8f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -51,6 +53,87 @@ void LLVMDialect::registerAttributes() {
}
//===----------------------------------------------------------------------===//
+// AddressSpaceAttr
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is an LLVM type that can be loaded or stored.
+static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!isLoadableType(type)) {
+ if (emitError)
+ emitError() << "type must be LLVM type with size, but got " << type;
+ return false;
+ }
+ if (ordering == ptr::AtomicOrdering::not_atomic)
+ return true;
+
+ // To check atomic validity we need a datalayout.
+ if (!dataLayout) {
+ if (emitError)
+ emitError() << "expected a valid data layout";
+ return false;
+ }
+ if (!isTypeCompatibleWithAtomicOp(type, *dataLayout)) {
+ if (emitError)
+ emitError() << "unsupported type " << type << " for atomic access";
+ return false;
+ }
+ return true;
+}
+
+bool AddressSpaceAttr::isValidLoad(
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+}
+
+bool AddressSpaceAttr::isValidStore(
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+}
+
+bool AddressSpaceAttr::isValidAtomicOp(
+ ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once `ptr.atomic_rmw` is implemented.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidAtomicXchg(
+ Type type, ptr::AtomicOrdering successOrdering,
+ ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidAddrSpaceCast(
+ Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once the `ptr.addrspace_cast` op is added to the
+ // dialect.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidPtrIntCast(
+ Type intLikeTy, Type ptrLikeTy,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once the int-cast ops are added to the `ptr`
+ // dialect.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
// AliasScopeAttr
//===----------------------------------------------------------------------===//
@@ -374,6 +457,43 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
getAttributeName());
}
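+
+/// Answers data-layout queries against the feature list: a key that appears
+/// verbatim yields a UnitAttr, while "+key" / "-key" entries are reported as
+/// true / false BoolAttrs.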
+FailureOr<Attribute> TargetFeaturesAttr::query(DataLayoutEntryKey key) {
+ auto stringKey = dyn_cast<StringAttr>(key);
+ if (!stringKey)
+ return failure();
+
+ if (contains(stringKey))
+ return UnitAttr::get(getContext());
+
+ if (contains((std::string("+") + stringKey.strref()).str()))
+ return BoolAttr::get(getContext(), true);
+
+ if (contains((std::string("-") + stringKey.strref()).str()))
+ return BoolAttr::get(getContext(), false);
+
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// TargetAttr
+//===----------------------------------------------------------------------===//
+
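+/// Answers data-layout queries on the target description; supports the
+/// "triple", "chip", and "features" keys.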
+FailureOr<::mlir::Attribute> TargetAttr::query(DataLayoutEntryKey key) {
+ if (auto stringAttrKey = dyn_cast<StringAttr>(key)) {
+ if (stringAttrKey.getValue() == "triple")
+ return getTriple();
+ if (stringAttrKey.getValue() == "chip")
+ return getChip();
+ if (stringAttrKey.getValue() == "features" && getFeatures())
+ return getFeatures();
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleFlagAttr
+//===----------------------------------------------------------------------===//
+
LogicalResult
ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LLVM::ModFlagBehavior flagBehavior, StringAttr key,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 422039f..ef27070 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
return success();
}
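+
+/// Builds the argument-attribute array that attaches an `llvm.align`
+/// parameter attribute to the pointer operand of the masked expandload /
+/// compressstore intrinsics. Returns a null attribute for the default
+/// alignment of 1.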
+static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
+ bool isExpandLoad,
+ uint64_t alignment = 1) {
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The pointer alignment defaults to 1.
+ if (alignment == 1) {
+ return nullptr;
+ }
+
+ auto emptyDictAttr = builder.getDictionaryAttr({});
+ auto alignmentAttr = builder.getI64IntegerAttr(alignment);
+ auto namedAttr =
+ builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
+ SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
+ auto alignDictAttr = builder.getDictionaryAttr(attrs);
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The align parameter attribute can be provided for [expandload]'s first
+ // argument. The align parameter attribute can be provided for
+ // [compressstore]'s second argument.
+ int pos = isExpandLoad ? 0 : 1;
+ return pos == 0 ? builder.getArrayAttr(
+ {alignDictAttr, emptyDictAttr, emptyDictAttr})
+ : builder.getArrayAttr(
+ {emptyDictAttr, alignDictAttr, emptyDictAttr});
+}
+
//===----------------------------------------------------------------------===//
// Operand bundle helpers.
//===----------------------------------------------------------------------===//
@@ -821,8 +853,8 @@ void LoadOp::getEffects(
/// Returns true if the given type is supported by atomic operations. All
/// integer, float, and pointer types with a power-of-two bitsize and a minimal
/// size of 8 bits are supported.
-static bool isTypeCompatibleWithAtomicOp(Type type,
- const DataLayout &dataLayout) {
+bool LLVM::isTypeCompatibleWithAtomicOp(Type type,
+ const DataLayout &dataLayout) {
if (!isa<IntegerType, LLVMPointerType>(type))
if (!isCompatibleFloatingPointType(type))
return false;
@@ -836,8 +868,9 @@ static bool isTypeCompatibleWithAtomicOp(Type type,
/// Verifies the attributes and the type of atomic memory access operations.
template <typename OpTy>
-LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
- ArrayRef<AtomicOrdering> unsupportedOrderings) {
+static LogicalResult
+verifyAtomicMemOp(OpTy memOp, Type valueType,
+ ArrayRef<AtomicOrdering> unsupportedOrderings) {
if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
DataLayout dataLayout = DataLayout::closest(memOp);
if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
@@ -1087,7 +1120,7 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
/// Verify that the parameter and return types of the variadic callee type match
/// the `callOp` argument and result types.
template <typename OpTy>
-LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
+static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
if (!varCalleeType)
return success();
@@ -2500,7 +2533,7 @@ LogicalResult GlobalOp::verifyRegions() {
// LLVM::GlobalCtorsOp
//===----------------------------------------------------------------------===//
-LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
+static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
if (data.empty())
return success();
@@ -4117,6 +4150,32 @@ LogicalResult LLVM::masked_scatter::verify() {
}
//===----------------------------------------------------------------------===//
+// masked_expandload (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
+ mlir::TypeRange resTys, Value ptr,
+ Value mask, Value passthru,
+ uint64_t align) {
+ ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
+ build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// masked_compressstore (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_compressstore::build(OpBuilder &builder,
+ OperationState &state, Value value,
+ Value ptr, Value mask, uint64_t align) {
+ ArrayAttr argAttrs =
+ getLLVMAlignParamForCompressExpand(builder, false, align);
+ build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
// InlineAsmOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index e7d5dad..ef38027 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "sroa"
@@ -734,9 +735,8 @@ static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
return false;
})
.Default([&](Type type) {
- LLVM_DEBUG(llvm::dbgs()
- << "[sroa] Unsupported type for offset computations"
- << type << "\n");
+          LDBG() << "[sroa] Unsupported type for offset computations: "
+                 << type;
return true;
});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 78b4411..297640c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -24,7 +24,9 @@ using namespace mlir::LLVM;
/// prints it as usual.
static void dispatchPrint(AsmPrinter &printer, Type type) {
if (isCompatibleType(type) &&
- !llvm::isa<IntegerType, FloatType, VectorType>(type))
+ !(llvm::isa<IntegerType, FloatType, VectorType>(type) ||
+ (llvm::isa<PtrLikeTypeInterface>(type) &&
+ !llvm::isa<LLVMPointerType>(type))))
return mlir::LLVM::detail::printType(type, printer);
printer.printType(type);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index fee2d3e..2dd0132 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -13,6 +13,7 @@
#include "TypeDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -701,6 +702,17 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
// Utility functions.
//===----------------------------------------------------------------------===//
+/// Check whether type is a compatible ptr type. These are pointer-like types
+/// with no element type, no metadata, and using the LLVM AddressSpaceAttr
+/// memory space.
+static bool isCompatiblePtrType(Type type) {
+ auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
+ if (!ptrTy)
+ return false;
+ return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
+ isa<AddressSpaceAttr>(ptrTy.getMemorySpace());
+}
+
bool mlir::LLVM::isCompatibleOuterType(Type type) {
// clang-format off
if (llvm::isa<
@@ -734,7 +746,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
if (auto vecType = llvm::dyn_cast<VectorType>(type))
return vecType.getRank() == 1;
- return false;
+ return isCompatiblePtrType(type);
}
static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
@@ -784,6 +796,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMX86AMXType
>([](Type) { return true; })
// clang-format on
+ .Case<PtrLikeTypeInterface>(
+ [](Type type) { return isCompatiblePtrType(type); })
.Default([](Type) { return false; });
if (!result)
@@ -805,6 +819,18 @@ bool mlir::LLVM::isCompatibleType(Type type) {
return LLVMDialect::isCompatibleType(type);
}
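+
+/// Returns true if `type` can be used as the value type of an LLVM load or
+/// store: a sized, LLVM-compatible type that is not void, a function type, an
+/// opaque struct, or a target extension type without memory-operation
+/// support.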
+bool mlir::LLVM::isLoadableType(Type type) {
+ return /*LLVM_PrimitiveType*/ (
+ LLVM::isCompatibleOuterType(type) &&
+ !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
+ /*LLVM_OpaqueStruct*/
+ !(isa<LLVM::LLVMStructType>(type) &&
+ cast<LLVM::LLVMStructType>(type).isOpaque()) &&
+ /*LLVM_AnyTargetExt*/
+ !(isa<LLVM::LLVMTargetExtType>(type) &&
+ !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
+}
+
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type, LLVMPPCFP128Type>(type);
@@ -818,7 +844,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
return intType.isSignless();
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type, LLVMPointerType>(elementType);
+ Float80Type, Float128Type, LLVMPointerType>(elementType) ||
+ isCompatiblePtrType(elementType);
}
return false;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ad429e..77ec1eb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
@@ -50,7 +51,6 @@ using namespace NVVM;
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
@@ -82,8 +82,27 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
- if (getCoordinates().size() > 5)
- return emitError("Maximum 5 coordinates and dimension is supported.");
+ TMAStoreMode mode = getMode();
+ // We lower through inline-ptx when getPredicate() is true.
+ // a) Only TILE mode is supported
+ // b) Cache-hint is not supported
+ if (getPredicate()) {
+ if (mode != TMAStoreMode::TILE)
+ return emitError("Inline-ptx lowering supported only for Tile mode.");
+ if (getL2CacheHint())
+ return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
+ }
+
+ size_t dims = getCoordinates().size();
+ switch (mode) {
+ case TMAStoreMode::TILE:
+ return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+ case TMAStoreMode::IM2COL:
+ return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+ case TMAStoreMode::TILE_SCATTER4:
+ if (dims != 5)
+ return emitError("Scatter4 mode expects 5 coordinates");
+ }
return success();
}
@@ -98,17 +117,59 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+// This parameter verification is shared by the TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+ TMALoadMode mode, Location loc) {
+ if (tensorDims < 1 || tensorDims > 5)
+ return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+ auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+ size_t expectedIm2colOff) -> LogicalResult {
+ if (isIm2col && (tensorDims < 3))
+ return emitError(loc)
+ << "to use " << stringifyEnum(mode)
+ << " mode, the tensor has to be at least 3-dimensional";
+
+ if (numIm2colOff != expectedIm2colOff)
+ return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+ << " (provided " << numIm2colOff << ")";
+
+ return success();
+ };
+
+ switch (mode) {
+ case TMALoadMode::TILE:
+ return checkTMALoadParams(mode, false, 0);
+ case TMALoadMode::IM2COL:
+ return checkTMALoadParams(mode, true, tensorDims - 2);
+ case TMALoadMode::IM2COL_W:
+ case TMALoadMode::IM2COL_W_128:
+ return checkTMALoadParams(mode, true, 2);
+ case TMALoadMode::TILE_GATHER4:
+ return (tensorDims == 5)
+ ? checkTMALoadParams(mode, false, 0)
+ : emitError(loc, "Gather4 mode expects 5 coordinates");
+ }
+ return success();
+}
+
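+
For reference, a few outcomes of the shared verifier above; an illustrative sketch only (not part of the patch), assuming a valid Location loc is in scope:

// IM2COL expects tensorDims - 2 offsets and at least a 3-D tensor.
assert(succeeded(verifyTMALoadParams(4, 2, TMALoadMode::IM2COL, loc)));
assert(failed(verifyTMALoadParams(2, 0, TMALoadMode::IM2COL, loc)));
// IM2COL_W / IM2COL_W_128 always expect exactly 2 offsets.
assert(succeeded(verifyTMALoadParams(4, 2, TMALoadMode::IM2COL_W, loc)));
// TILE_GATHER4 requires exactly 5 coordinates and no offsets.
assert(succeeded(verifyTMALoadParams(5, 0, TMALoadMode::TILE_GATHER4, loc)));
assert(failed(verifyTMALoadParams(4, 0, TMALoadMode::TILE_GATHER4, loc)));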
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
- size_t numIm2ColOffsets = getIm2colOffsets().size();
- bool isIm2Col = numIm2ColOffsets > 0;
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
- numIm2ColOffsets, getLoc());
+ return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+ getMode(), getLoc());
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
- bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
- getLoc());
+ TMAStoreMode mode = getMode();
+ size_t dims = getCoordinates().size();
+ switch (mode) {
+ case TMAStoreMode::TILE:
+ return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+ case TMAStoreMode::IM2COL:
+ return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+ case TMAStoreMode::TILE_SCATTER4:
+ return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
+ }
+ return success();
}
LogicalResult ConvertFloatToTF32Op::verify() {
@@ -811,24 +872,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
}
LogicalResult NVVM::LdMatrixOp::verify() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
- if (addressSpace != NVVM::kSharedMemorySpace)
- return emitOpError("expected source pointer in memory space 3");
-
- if (getNum() != 1 && getNum() != 2 && getNum() != 4)
- return emitOpError("expected num attribute to be 1, 2 or 4");
+ uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
+ if (m == 8 && n == 8) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
+ "matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B16) {
+ return emitOpError("expected element type to be b16 for 8x8 matrix");
+ }
+ } else if (m == 8 && n == 16) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::row) {
+ return emitOpError("expected layout to be row for 8x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 8x16 matrix");
+ }
+ } else if (m == 16 && n == 16) {
+ if (num != 1 && num != 2) {
+ return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::col) {
+ return emitOpError("expected layout to be col for 16x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8 &&
+ getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 16x16 matrix");
+ }
+ } else {
+ return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
+ }
Type i32 = IntegerType::get(getContext(), 32);
- if (getNum() == 1 && getType() != i32)
+ uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
+ if (numElements == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
- if (getNum() == 2 || getNum() == 4) {
+ if (numElements == 2 || numElements == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
- getContext(), SmallVector<Type>(getNum(), i32));
+ getContext(), SmallVector<Type>(numElements, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
- << getNum() << " elements of type i32";
+ << numElements << " elements of type i32";
}
+
return success();
}
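
For the 16x16 shape each ldmatrix fragment occupies two i32 registers per num, which is why the verifier above doubles num before checking the destination type. Illustrative expectations, not part of the patch:

// shape 8x8,   num 1 -> i32
// shape 8x8,   num 4 -> !llvm.struct<(i32, i32, i32, i32)>
// shape 16x16, num 1 -> numElements 2 -> !llvm.struct<(i32, i32)>
// shape 16x16, num 2 -> numElements 4 -> !llvm.struct<(i32, i32, i32, i32)>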
@@ -1089,7 +1184,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
return ptx;
}
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
@@ -1120,7 +1215,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
{makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
mlir::NVVM::PTXRegisterMod::Read});
}
+ return true; // Has manual mapping
}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1236,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() {
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
+ std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
- if (getUniform()) {
- if (getCacheLevel() != CacheLevel::L1)
- return emitOpError("unsupported cache level, the only supported uniform "
- "cache level is L1");
+ if (getTensormap() && cacheLevel)
+ return emitOpError("cannot specify both tensormap and cache level");
- if (addressSpace != MemSpace::kGenericMemorySpace)
+ if (getTensormap()) {
+ if (addressSpace != MemSpace::kGenericMemorySpace &&
+ addressSpace != MemSpace::kConstantMemorySpace) {
return emitOpError(
- "prefetch to uniform cache requires a generic pointer");
- }
+ "prefetch tensormap requires a generic or constant pointer");
+ }
- if (evictPriority) {
- if (getCacheLevel() != CacheLevel::L2)
+ if (evictPriority) {
return emitOpError(
- "cache eviction priority supported only for cache level L2");
-
- if (addressSpace != MemSpace::kGlobalMemorySpace)
- return emitOpError("cache eviction priority requires a global pointer");
+ "prefetch tensormap does not support eviction priority");
+ }
- if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
- *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+ if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
return emitOpError(
- "unsupported cache eviction priority, only evict_last and "
- "evict_normal are supported");
+ "in_param_space can only be specified for a generic pointer");
+ }
+
+ } else if (cacheLevel) {
+ if (addressSpace != MemSpace::kGenericMemorySpace &&
+ addressSpace != MemSpace::kGlobalMemorySpace &&
+ addressSpace != MemSpace::kLocalMemorySpace) {
+ return emitOpError("prefetch to cache level requires a generic, global, "
+ "or local pointer");
+ }
+
+ if (getUniform()) {
+ if (*cacheLevel != CacheLevel::L1) {
+ return emitOpError(
+ "unsupported cache level, the only supported uniform "
+ "cache level is L1");
+ }
+
+ if (addressSpace != MemSpace::kGenericMemorySpace) {
+ return emitOpError(
+ "prefetch to uniform cache requires a generic pointer");
+ }
+ }
+
+ if (evictPriority) {
+ if (*cacheLevel != CacheLevel::L2)
+ return emitOpError(
+ "cache eviction priority supported only for cache level L2");
+
+ if (addressSpace != MemSpace::kGlobalMemorySpace)
+ return emitOpError("cache eviction priority requires a global pointer");
+
+ if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
+ *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+ return emitOpError(
+ "unsupported cache eviction priority, only evict_last and "
+ "evict_normal are supported");
+ }
+
+ if (getPredicate())
+ return emitOpError("predicate supported only on prefetch tensormap");
+
+ } else {
+ return emitOpError(
+ "requires specification of either cache level or tensormap");
}
return success();
@@ -1399,28 +1536,102 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
+ {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
+
+ static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
+ "TMAStoreModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable(
+ "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
+
+ return {id, std::move(args)};
}
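
Both translations above use the same table-driven selection: rows are indexed by the mode enum value, columns by the coordinate count, and not_intrinsic marks combinations the verifier already rejects. Reduced to its essentials, a hedged sketch (not part of the patch):

template <size_t NumModes>
static llvm::Intrinsic::ID
lookupTMAIntrinsic(const llvm::Intrinsic::ID (&table)[NumModes][6], size_t mode,
                   size_t dims) {
  assert(mode < NumModes && dims < 6 && "index out of range");
  llvm::Intrinsic::ID id = table[mode][dims];
  assert(id != llvm::Intrinsic::not_intrinsic && "unsupported mode/dims pair");
  return id;
}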
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
@@ -1794,26 +2005,47 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
return {ids[type], args};
}
-llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
+static llvm::Value *getParamCastedAddr(llvm::Value *addr,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateAddrSpaceCast(
+ addr,
+ llvm::PointerType::get(builder.getContext(),
+ llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
+}
+
+NVVM::IDArgPair
+PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;
- NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
+ std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
std::optional<NVVM::CacheEvictionPriority> evictPriority =
op.getEvictPriority();
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
.getAddressSpace();
- if (op.getUniform() && cacheLevel == CacheLevel::L1)
- return llvm::Intrinsic::nvvm_prefetchu_L1;
+ llvm::SmallVector<llvm::Value *> args;
+ llvm::Value *addr = mt.lookupValue(op.getAddr());
+ args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
+ : addr);
- if (evictPriority && cacheLevel == CacheLevel::L2) {
+ if (op.getTensormap())
+ return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
+
+ assert(cacheLevel && "expected cache level for non-tensormap prefetch");
+
+ if (op.getUniform() && *cacheLevel == CacheLevel::L1)
+ return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
+
+ if (evictPriority && *cacheLevel == CacheLevel::L2) {
switch (*evictPriority) {
case NVVM::CacheEvictionPriority::EvictLast:
- return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
+ return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
case NVVM::CacheEvictionPriority::EvictNormal:
- return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
+ return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
default:
llvm_unreachable("Invalid cache eviction priority");
}
@@ -1821,21 +2053,41 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
switch (addressSpace) {
case MemSpace::kGenericMemorySpace:
- return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
- : llvm::Intrinsic::nvvm_prefetch_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
+ : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
case MemSpace::kGlobalMemorySpace:
- return cacheLevel == CacheLevel::L1
- ? llvm::Intrinsic::nvvm_prefetch_global_L1
- : llvm::Intrinsic::nvvm_prefetch_global_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
+ : NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
case MemSpace::kLocalMemorySpace:
- return cacheLevel == CacheLevel::L1
- ? llvm::Intrinsic::nvvm_prefetch_local_L1
- : llvm::Intrinsic::nvvm_prefetch_local_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
+ : NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
default:
llvm_unreachable("Invalid pointer address space");
}
}
+bool NVVM::InlinePtxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ for (auto arg : getReadWriteArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
+ for (auto arg : getResults())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto arg : getReadOnlyArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
+ if (getPredicate())
+ asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
+ return false; // No manual mapping needed
+}
+
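+
A hedged sketch of how a PTX lowering could consume this hook; registerPtxOperand is a hypothetical consumer, not an existing API:

static void collectAsmOperands(NVVM::InlinePtxOp op, RewriterBase &rewriter) {
  SmallVector<std::pair<Value, NVVM::PTXRegisterMod>> asmValues;
  // `false` means the default ordering captured in asmValues is final;
  // WgmmaMmaAsyncOp above returns `true` to signal its manual mapping.
  bool hasManualMapping = op.getAsmValues(rewriter, asmValues);
  (void)hasManualMapping;
  for (auto &[value, mod] : asmValues)
    registerPtxOperand(value, mod);
}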
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -1874,19 +2126,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
attrName == NVVMDialect::getReqntidAttrName() ||
attrName == NVVMDialect::getClusterDimAttrName()) {
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
- if (!values || values.empty() || values.size() > 3)
+ if (!values || values.empty() || values.size() > 3) {
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
+ }
}
// If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
// attribute
if (attrName == NVVMDialect::getMinctasmAttrName() ||
attrName == NVVMDialect::getMaxnregAttrName() ||
attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
- if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
+ if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
return op->emitError()
<< "'" << attrName << "' attribute must be integer constant";
+ }
+ }
+ // blocksareclusters must be used along with reqntid and cluster_dim
+ if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
+ if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
+ !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
+ return op->emitError()
+ << "'" << attrName << "' attribute must be used along with "
+ << "'" << NVVMDialect::getReqntidAttrName() << "' and "
+ << "'" << NVVMDialect::getClusterDimAttrName() << "'";
+ }
}
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index 8317b67..23b4130 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
using namespace LLVM;
@@ -63,9 +63,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
}
if (maxNumRewrites && numRewrites >= *maxNumRewrites) {
- LLVM_DEBUG(llvm::dbgs()
- << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
- << maxNumRewrites << ")\n");
+ LDBG() << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
+ << maxNumRewrites << ")";
// Skip rewriting the rest.
result.append(inputs.begin(), inputs.end());
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 18f85b6..4ea2ac9 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -235,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) {
WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
// Attempt to advance to the source of the underlying view-like operation.
// Examples of view-like operations include GEPOp and AddrSpaceCastOp.
- if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>())
- return WalkContinuation::advanceTo(viewOp.getViewSource());
+ if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) {
+ if (val == viewOp.getViewDest())
+ return WalkContinuation::advanceTo(viewOp.getViewSource());
+ }
// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 34c63d3..578931e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
ArrayRef<AffineMap> indexingMaps) {
// Initialize indexingMaps attribute, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
- indexingMapsAttrVal = llvm::map_to_vector(
- MatmulOp::getDefaultIndexingMaps(b.getContext()),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
@@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
+ // 2) The payload op must have the same number of operands as the number of
+ // block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
- return nullptr;
+ return false;
+
+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+ // must be the first operand of the payload op, otherwise, the operands
+ // must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
- return nullptr;
+ return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+
+ // 4) The `yield` operand must be the result of the payload op.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1621,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
// MatMulOp
//===----------------------------------------------------------------------===//
+static FailureOr<SmallVector<SmallVector<int64_t>>>
+getAffineResultPositions(ArrayAttr maps) {
+ SmallVector<SmallVector<int64_t>> positions;
+ for (auto map : maps) {
+ AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
+ if (!attr)
+ return failure();
+ SmallVector<int64_t> pos;
+ for (auto result : attr.getAffineMap().getResults()) {
+ auto dim = dyn_cast<AffineDimExpr>(result);
+ if (!dim)
+ return failure();
+ pos.push_back(dim.getPosition());
+ }
+ positions.push_back(pos);
+ }
+ return positions;
+}
+
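+
As a concrete illustration: for the canonical matmul maps (d0, d2), (d2, d1), (d0, d1), getAffineResultPositions yields {{0, 2}, {2, 1}, {0, 1}}, which is exactly the pattern MatmulOp::isDefaultIndexingMaps matches below. A minimal check, assuming an MLIRContext *ctx (sketch only, not part of the patch):

Builder b(ctx);
ArrayAttr maps = b.getAffineMapArrayAttr(MatmulOp::getDefaultIndexingMaps(ctx));
assert(MatmulOp::isDefaultIndexingMaps(maps));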
/// Returns a list of AffineMap with the typical matmul indexing characteristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
AffineExpr d0, d1, d2;
@@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
@@ -3836,7 +3880,7 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
-FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
if (parser.parseOptionalKeyword("indexing_maps"))
return ArrayAttr{
nullptr}; // Success in case indexing_maps was not provided.
@@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+SmallVector<AffineMap>
+MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{1, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
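+
With these classof hooks, the transpose variants behave as views over plain linalg.matmul / linalg.batch_matmul, recognized purely from the indexing_maps attribute, so ordinary casts keep working. A minimal sketch, not part of the patch:

// A linalg.matmul whose LHS map is (d2, d0) is classified as transpose-A.
static bool isTransposedAMatmul(Operation *op) {
  return llvm::isa_and_nonnull<linalg::MatmulTransposeAOp>(op);
}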
//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//
@@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{
utils::IteratorType::parallel, utils::IteratorType::parallel,
@@ -5042,7 +5474,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
-bool areTilesAndTiledDimsAllConstant(OpTy op) {
+static bool areTilesAndTiledDimsAllConstant(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
@@ -5345,11 +5777,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getSourceType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
@@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{1, 2};
+}
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
std::string BatchReduceMatmulOp::getLibraryCallName() {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 639e0fe..f0c1f44 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -70,12 +70,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
- struct TrivialPatternRewriter : public PatternRewriter {
- public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
- };
- TrivialPatternRewriter rewriter(operation->getContext());
+ PatternRewriter rewriter(operation->getContext());
rewriter.setInsertionPoint(operation);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73..6912da3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -55,8 +55,8 @@ static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
// Skip the batch dimension if present.
// Offset all dimensions accordingly.
SmallVector<int64_t, 3> offsetDims(dims);
- for (size_t i = 0; i < offsetDims.size(); i++)
- offsetDims[i] += batchDimsOffset;
+ for (int64_t &offsetDim : offsetDims)
+ offsetDim += batchDimsOffset;
auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
OpBuilder builder(tileOp);
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
- patterns.getContext(), controlFn);
+ BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
+ controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6ec2e9fd..fb39e186 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -26,7 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
- NamedOpConversions.cpp
+ SimplifyDepthwiseConv.cpp
NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index d1eb270..108abe8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}
-// Delinearizes the given composite `index` by the basis specified in `factors`.
-static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
- ArrayRef<int64_t> factors) {
- assert(!factors.empty() && "empty factor list");
- SmallVector<Value> basis;
- for (int64_t f : factors)
- basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
- FailureOr<SmallVector<Value>> multiIndex =
- affine::delinearizeIndex(b, loc, index, basis);
- assert(!failed(multiIndex) && "Failed to linearize img2col index");
- return *multiIndex;
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+ bool useSymbols = true) {
+ AffineExpr oExpr, fExpr;
+ if (useSymbols)
+ bindSymbols(b.getContext(), oExpr, fExpr);
+ else
+ bindDims(b.getContext(), oExpr, fExpr);
+ return AffineExpr(stride * oExpr + fExpr);
}
-// Given indices corresponding to iterators in the output (oIndex) and filter
-// (fIndex) for a convolution, compute the convolved index for the
-// input as `oIndex * stride + fIndex`.
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
- Value fIndex, int64_t stride) {
- AffineExpr oExpr, fExpr;
- bindSymbols(b.getContext(), oExpr, fExpr);
- AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
- return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the corresponding indices of the output and filter matrices
+struct Im2ColToOperandsExprs {
+ AffineExpr fhIndex;
+ AffineExpr fwIndex;
+ AffineExpr icIndex;
+ AffineExpr ohIndex;
+ AffineExpr owIndex;
+};
+
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the input matrix indices
+struct Im2ColToInputDimsExprs {
+ AffineExpr bIndex;
+ AffineExpr hIndex;
+ AffineExpr wIndex;
+ AffineExpr cIndex;
+};
+
+/// Construct the affine expressions that map the indices of the im2col matrix
+/// to the corresponding input tensor indices for a 2D convolution with the
+/// provided strides.
+///
+/// @param exprs Affine expressions for output and filter indices.
+/// @param strides [height, width] stride values for the convolution.
+/// @param rewriter Pattern rewriter.
+/// @return Affine expressions mapping im2col matrix indices to input
+/// offsets.
+static Im2ColToInputDimsExprs
+getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
+ ArrayRef<int64_t> strides, RewriterBase &rewriter) {
+ // maps the iteration space of the im2col matrix to (output_y, filter_y)
+ auto hIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
+ // maps the iteration space of the im2col matrix to (output_x, filter_x)
+ auto wIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
+ // Compute the input indexing map, to map the indices of the im2col matrix to
+ // the original input offsets. Each element of the im2col matrix corresponds
+ // to a pair of (out_element, filter_element). First, we build the expressions
+ // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
+ // then we compose them with the maps that map the im2col matrix elements to
+ // the (out_element, filter_element) pairs.
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+ auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
+ /*useSymbols*/ false);
+ hIndexExpr = hIndexExpr.compose(hIndicesMap);
+ auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
+ /*useSymbols*/ false);
+ wIndexExpr = wIndexExpr.compose(wIndicesMap);
+ auto cIndexExpr = exprs.icIndex;
+ return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}
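
For intuition, with strides [sh, sw] and im2col iteration space (d0, d1, d2) = (batch, oh*ow, fh*fw*ic) as in the NHWC rewrite below, the composed expressions work out to (illustrative only):

// bIndex = d0
// hIndex = sh * (d1 floordiv ow) + d2 floordiv (fw * ic)
// wIndex = sw * (d1 mod ow) + (d2 mod (fw * ic)) floordiv ic
// cIndex = d2 mod ic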
FailureOr<std::pair<Operation *, Operation *>>
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
+
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
- SmallVector<AffineMap, 4> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ // Recover the original iteration indices from the problem/input sizes:
+ // given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
+ ArrayRef<int64_t>{fh * fw, fw, 1});
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.icIndex = kIndicesExprs[0];
+ i2cToOperExprs.fhIndex = kIndicesExprs[1];
+ i2cToOperExprs.fwIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
+ inExprs.hIndex, inExprs.wIndex}},
+ rewriter.getContext())[0];
+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+ SmallVector<AffineMap> img2colIndexingMaps = {
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
- auto icIndex = kIndices[0];
- auto fhIndex = kIndices[1];
- auto fwIndex = kIndices[2];
-
- SmallVector<Value> nIndices = unrollIndex(
- nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = nIndices[0];
- auto owIndex = nIndices[1];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
- SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Value reshapedOutput = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedOutputType, output, outputReassocIndices);
+ // Shape of the Toeplitz matrix produced by Im2col.
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
inputType.getElementType());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because we didn't transpose the filters we don't actually have a batched
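The im2col rewrites above all replace the per-element linalg.index / unrollIndex / getConvolvedIndex arithmetic with a single affine indexing map on the new input operand. A minimal sketch of how such a map can be assembled, assuming the delinearize helper the patch calls (from mlir/Dialect/Utils/IndexingUtils.h) and illustrative sizes fh = fw = 3, ic = 8; the function name is made up:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

// With strides {fw * ic, ic, 1} = {24, 8, 1}, the flattened k index d2 splits
// into fh' = d2 floordiv 24, fw' = (d2 mod 24) floordiv 8, ic' = d2 mod 8.
static mlir::AffineMap exampleIm2ColInputMap(mlir::MLIRContext *ctx) {
  mlir::Builder b(ctx);
  int64_t fw = 3, ic = 8;
  llvm::SmallVector<mlir::AffineExpr> k = mlir::delinearize(
      b.getAffineDimExpr(2), llvm::ArrayRef<int64_t>{fw * ic, ic, 1});
  // Map from the (n, m, k) im2col iteration space to an (n, fh, fw, ic)
  // access; the real rewrite also folds in the conv strides and the oh/ow
  // split before inferring the final map.
  return mlir::AffineMap::inferFromExprList(
      {llvm::ArrayRef<mlir::AffineExpr>{b.getAffineDimExpr(0), k[0], k[1],
                                        k[2]}},
      ctx)[0];
}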
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4..2ff7f46 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -75,7 +75,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
// layout for best compatibility.
Value toBuffer = bufferization::ToBufferOp::create(
b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
- tensorSource, /*readOnly=*/true);
+ tensorSource, /*read_only=*/true);
memref::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
@@ -84,7 +84,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
// layout for best compatibility.
Value toBuffer = bufferization::ToBufferOp::create(
b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
- tensorSource, /*readOnly=*/true);
+ tensorSource, /*read_only=*/true);
linalg::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c176..40085a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ private:
ControlPropagationFn controlFn;
};
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+ OpFoldResult offset;
+ OpFoldResult sliceSize;
+ OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
+ OpOperand *sliceOperand = nullptr;
+ for (auto operand : genericOp.getDpsInputOperands()) {
+ auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ continue;
+ sliceOperand = operand;
+ break;
+ }
+ if (!sliceOperand) {
+ return failure();
+ }
+ return sliceOperand;
+}
+
+// Return a map of dims that have partial slices on them so that other operands
+// can use this information. Also return a bool indicating whether a reduction
+// dim has a non-full slice, as that can be used to fold the original extract slice.
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+ llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+ SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+ genericOp.getContext(), producerSliceOp.getSourceType().getShape());
+
+ for (auto [idx, expr] : llvm::enumerate(
+ genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    // If we have a full slice in a dimension, then we don't need to add it to
+ // the partial slice map.
+ if (isConstantIntValue(offsets[idx], 0) &&
+ isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+ continue;
+ }
+    // We only support partial slices of AffineDimExprs, so bail out if that's not
+ // the case.
+ if (!isa<AffineDimExpr>(expr)) {
+ return failure();
+ }
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+ partialSliceDimMap[dimPos] = sliceDimInfo;
+ }
+  // Next, check whether the dims with partial slice info are used in
+  // non-AffineDimExpr expressions in other operands; if they are, bail out.
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ if (operand == *sliceOperand) {
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+ if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+ if (isa<AffineDimExpr>(expr)) {
+ return false;
+ }
+ WalkResult status = expr.walk([&](AffineExpr expr) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (partialSliceDimMap.contains(dimExpr.getPosition())) {
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (status.wasInterrupted()) {
+ return true;
+ }
+ return false;
+ })) {
+ return failure();
+ }
+ }
+ return partialSliceDimMap;
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp,
+ ControlPropagationFn controlFn) {
+ if (genericOp.getNumResults() != 1)
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation through multi-result generic is unsupported.");
+ if (hasGatherSemantics(genericOp))
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation through generic with gather semantics is unsupported.");
+ // Collect the sliced operand, if present.
+ auto maybeSliceOperand = getSliceOperand(genericOp);
+ if (failed(maybeSliceOperand))
+ return failure();
+ OpOperand *sliceOperand = *maybeSliceOperand;
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
+
+ if (!controlFn(sliceOperand))
+ return failure();
+
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+ if (producerSliceOp.getSource().getType().getRank() !=
+ producerSliceOp.getResult().getType().getRank()) {
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation of rank-reducing extract slice is unsupported.");
+ }
+
+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+ if (!areAllConstantIntValue(strides, 1))
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation of strided extract slice is unsupported.");
+
+  // Check whether we can support propagating this extractSlice through the
+  // generic op and, if so, collect the dimensions that carry partial slices.
+
+ auto maybePartialSliceDimMap =
+ getPartialSliceDimInfo(genericOp, sliceOperand);
+
+ if (failed(maybePartialSliceDimMap)) {
+ return failure();
+ }
+
+ auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+ bool hasPartialReductionDimSlice =
+ llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+ int64_t sliceDim = slice.first;
+ return iterators[sliceDim] == utils::IteratorType::reduction;
+ });
+
+ // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
+ Location loc = genericOp->getLoc();
+ AffineExpr dim0, dim1;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+ {v1, v2});
+ };
+
+ MLIRContext *ctx = genericOp.getContext();
+ SmallVector<Value> paddedInputs;
+ for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ if (idx == OperandIndex && !hasPartialReductionDimSlice) {
+ paddedInputs.push_back(producerSliceOp.getSource());
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+ SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
+ }
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ operandLowPads[idx] = sliceDimInfo.offset;
+ operandHighPads[idx] =
+ sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ }
+ auto paddingValue = ub::PoisonOp::create(
+ rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+ auto paddedOperand = tensor::PadOp::create(
+ rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+ paddingValue, /*nofold=*/false);
+ paddedInputs.push_back(paddedOperand);
+ }
+ AffineMap outputIndexingMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+ auto outputShapeType =
+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+ outputShapeType.getShape(),
+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+ SmallVector<OpFoldResult> newSizes = OutputShape;
+ SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 1));
+ for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
+ }
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ outputLowPads[idx] = sliceDimInfo.offset;
+ outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ OutputShape[idx] = sliceDimInfo.outputSize;
+ newSizes[idx] = sliceDimInfo.sliceSize;
+ }
+ Value newPadOutput;
+ auto outputElType =
+ getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+ if (isGenericOutsNotUsed(genericOp)) {
+ newPadOutput =
+ tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+ } else {
+ auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+ newPadOutput = tensor::PadOp::create(
+ rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+ outputHighPads, paddingValue, /*nofold=*/false);
+ }
+
+ auto newGenericOp = linalg::GenericOp::create(
+ rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+ genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+ newGenericOp.getRegion().begin());
+
+ auto extractOp = tensor::ExtractSliceOp::create(
+ rewriter, loc,
+ newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+ outputLowPads, newSizes, newStrides);
+ Value extractRes = extractOp.getResult();
+
+ return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+ : public OpRewritePattern<GenericOp> {
+public:
+ PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ auto genericAndRepl =
+ pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+ if (failed(genericAndRepl))
+ return failure();
+ rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+ return success();
+ }
+
+private:
+ ControlPropagationFn controlFn;
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation) {
+ patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+ patterns.getContext(), controlPackUnPackPropagation);
+}
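The new PushDownExtractSliceOpThroughGenericOp pattern pads every other operand back to the full iteration space with a ub.poison fill (lowPad = offset and highPad = outputSize - offset - sliceSize per sliced dimension) and re-applies the slice to the generic's result. It is exposed through populateExtractSliceSinkingPatterns; a minimal, hypothetical driver for it (the wrapper function, the always-true control callback, and the header location are assumptions; only the populate entry point comes from the patch):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sink producer tensor.extract_slice ops below every eligible linalg.generic
// nested under `root`, accepting every candidate operand.
static mlir::LogicalResult sinkExtractSlices(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateExtractSliceSinkingPatterns(
      patterns, /*controlFn=*/[](mlir::OpOperand *) { return true; });
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}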
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index bf66ed0..22690da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
auto newResultType = RankedTensorType::get(
newResultShape, padOp.getResultType().getElementType());
- auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
- newHighPad, paddingVal, padOp.getNofold());
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
+ newLowPad, newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
if (options.rankReductionStrategy ==
@@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c523153..baf4083 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,26 @@ namespace mlir {
using namespace mlir;
+static inline bool isScalarLike(Type t) {
+ return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ return !(isa<RankedTensorType>(t) || isScalarLike(t));
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
- auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
- SmallVector<AffineMap, 3> indexingMaps(
- op->getNumResults() + op->getNumOperands(),
- rewriter.getMultiDimIdentityMap(rank));
- SmallVector<utils::IteratorType, 6> iteratorTypes(
+ auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
+ auto rank = resTy.getRank();
+
+ // Maps: identity for tensors (rank > 0), scalar map for scalars.
+ AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
+ /*results=*/{}, rewriter.getContext());
+ AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
+
+ // Match phase.
+ SmallVector<bool> isScalarOperand;
+ isScalarOperand.reserve(op->getNumOperands());
+ for (Type ty : op->getOperandTypes()) {
+ if (isScalarLike(ty))
+ isScalarOperand.push_back(true);
+ else if (auto rt = dyn_cast<RankedTensorType>(ty))
+ isScalarOperand.push_back(false);
+ else
+ return rewriter.notifyMatchFailure(
+ op,
+ "unsupported operand type (expected scalar-like or ranked tensor)");
+ }
+
+ // Create indexing maps.
+ SmallVector<AffineMap> indexingMaps;
+ indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+ for (bool isScalar : isScalarOperand)
+ indexingMaps.push_back(isScalar ? scalarMap : idMap);
+
+ indexingMaps.append(op->getNumResults(), idMap);
+
+ SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
- auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
+ SmallVector<Value> outputs =
+ getOrCreateOperandsMatchingResultTypes(rewriter, op);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/op->getOperands(),
@@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
- auto resultTypes = llvm::to_vector<6>(
+ SmallVector<Type> resultEltTys = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return cast<TensorType>(type).getElementType();
}));
- auto *scalarOp =
+ Operation *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
regionArgs.take_front(op->getNumOperands()),
- resultTypes, op->getAttrs());
+ resultEltTys, op->getAttrs());
linalg::YieldOp::create(builder, loc, scalarOp->getResults());
});
return success();
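With the relaxed precondition above, scalar-like operands (integer, float, index, complex) flow straight into the linalg.generic and are read through an indexing map with no results, so the same value is broadcast at every iteration point. A small self-contained sketch of the two map shapes for a rank-2 result (names are illustrative):

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

// For rank 2:
//   scalarMap: (d0, d1) -> ()        reads the one scalar at every point
//   idMap:     (d0, d1) -> (d0, d1)  element-wise access into a tensor
static void exampleElementwiseMaps(mlir::MLIRContext *ctx) {
  mlir::Builder b(ctx);
  unsigned rank = 2;
  mlir::AffineMap scalarMap = mlir::AffineMap::get(
      /*dimCount=*/rank, /*symbolCount=*/0, /*results=*/{}, ctx);
  mlir::AffineMap idMap = b.getMultiDimIdentityMap(rank);
  (void)scalarMap;
  (void)idMap;
}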
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index fd530f2..9436f1c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
auto clonedForOp = scf::ForOp::create(
rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()),
bvm.lookupOrDefault(forOp.getUpperBound()),
- bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
+ bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Map the induction var, region args and results to the `clonedForOp`.
bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 58986a6..36434cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
scf::ForOp newLoop = scf::ForOp::create(
rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loop.getUnsignedCmp());
// Generate the new yield with the replaced operand.
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
@@ -165,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
Value source = transferRead.getBase();
   // Skip view-like Ops and retrieve the actual source Operation
- while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
- source = srcOp.getViewSource();
+ while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
+ if (viewLike.getViewDest() != source) {
+ break;
+ }
+ source = viewLike.getViewSource();
+ }
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
@@ -177,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
if (!processed.insert(user).second)
continue;
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
- users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
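This hunk, and the ExpandStridedMetadata and MemRefUtils hunks further down, apply the same guard: only step through a ViewLikeOpInterface op when the traced value is that op's view destination. A minimal sketch of the idiom, assuming the getViewDest accessor these hunks rely on; the helper name is made up:

#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

// Walk up a chain of view-like producers, stopping as soon as the traced
// value is not the op's aliasing result (its view destination).
static mlir::Value skipViewLikeProducers(mlir::Value v) {
  while (auto viewLike = v.getDefiningOp<mlir::ViewLikeOpInterface>()) {
    if (viewLike.getViewDest() != v)
      break;
    v = viewLike.getViewSource();
  }
  return v;
}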
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 3d12bc3..8942670 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -263,11 +263,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface>
-linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- PadSizeComputationFunction computePaddingSizeFun) {
+FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
+ RewriterBase &rewriter, TilingInterface opToPad,
+ const PadTilingInterfaceOptions &constOptions,
+ SmallVector<tensor::PadOp> &padOps,
+ const PadSizeComputationFunction &computePaddingSizeFun) {
LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
Location loc = opToPad.getLoc();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index a2bd9d9..27ccf3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -21,7 +21,7 @@
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
-#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
+#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
@@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp
}
};
-struct LinalgNamedOpConversionPass
- : public impl::LinalgNamedOpConversionPassBase<
- LinalgNamedOpConversionPass> {
- using impl::LinalgNamedOpConversionPassBase<
- LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
+struct SimplifyDepthwiseConvPass
+ : public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> {
+ using impl::SimplifyDepthwiseConvPassBase<
+ SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- populateLinalgNamedOpConversionPatterns(patterns);
+ populateSimplifyDepthwiseConvPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
-void mlir::linalg::populateLinalgNamedOpConversionPatterns(
+void mlir::linalg::populateSimplifyDepthwiseConvPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
patterns.getContext());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6..35ba4f15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bb725f2..e9a8b25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -29,6 +29,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <utility>
@@ -38,9 +39,6 @@
using namespace mlir;
using namespace mlir::linalg;
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
-
//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//
@@ -91,11 +89,11 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
}
return true;
}
+#endif // NDEBUG
static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}
-#endif // NDEBUG
/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
@@ -276,23 +274,18 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
highs, paddingValue, /*nofold=*/false);
- LLVM_DEBUG(
- DBGSNL(); DBGSNL();
- DBGS() << "insertPositions: "
- << llvm::interleaved(packingMetadata.insertPositions);
- DBGSNL(); DBGS() << "outerPositions: "
- << llvm::interleaved(packingMetadata.outerPositions);
- DBGSNL(); DBGS() << "packedShape: "
- << llvm::interleaved(packedTensorType.getShape());
- DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
- << llvm::interleaved(packedToStripMinedShapePerm);
- DBGSNL();
- DBGS() << "reassociations: "
- << llvm::interleaved(llvm::map_range(
- packingMetadata.reassociations, stringifyReassocIndices));
- DBGSNL();
- DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
- DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+ LDBG() << "insertPositions: "
+ << llvm::interleaved(packingMetadata.insertPositions);
+ LDBG() << "outerPositions: "
+ << llvm::interleaved(packingMetadata.outerPositions);
+ LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
+ LDBG() << "packedToStripMinedShapePerm: "
+ << llvm::interleaved(packedToStripMinedShapePerm);
+ LDBG() << "reassociations: "
+ << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
+ stringifyReassocIndices));
+ LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
+ LDBG() << "collapsed type: " << collapsed;
if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
// Pack ops which operate as simple pads may not produce legal
@@ -317,7 +310,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
/*offsets=*/zeros, sizes, /*strides=*/ones);
- LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+ LDBG() << "insert_slice op: " << insertSliceOp;
rewriter.replaceOp(packOp, insertSliceOp->getResults());
@@ -339,10 +332,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
auto transposeOp = linalg::TransposeOp::create(
rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "reshape op: " << reshapeOp; DBGSNL();
- DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
- DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+ LDBG() << "reshape op: " << reshapeOp;
+ LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
+ LDBG() << "transpose op: " << transposeOp;
// 7. Replace packOp by transposeOp.
rewriter.replaceOp(packOp, transposeOp->getResults());
@@ -410,21 +402,16 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
packedToStripMinedShapePerm);
- LLVM_DEBUG(
- DBGSNL(); DBGSNL();
- DBGS() << "insertPositions: "
- << llvm::interleaved(packingMetadata.insertPositions);
- DBGSNL(); DBGS() << "packedShape: "
- << llvm::interleaved(packedTensorType.getShape());
- DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
- << llvm::interleaved(packedToStripMinedShapePerm);
- DBGSNL();
- DBGS() << "reassociations: "
- << llvm::interleaved(llvm::map_range(
- packingMetadata.reassociations, stringifyReassocIndices));
- DBGSNL();
- DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
- DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+ LDBG() << "insertPositions: "
+ << llvm::interleaved(packingMetadata.insertPositions);
+ LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
+ LDBG() << "packedToStripMinedShapePerm: "
+ << llvm::interleaved(packedToStripMinedShapePerm);
+ LDBG() << "reassociations: "
+ << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
+ stringifyReassocIndices));
+ LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
+ LDBG() << "collapsed type: " << collapsedType;
// 4. Collapse from the stripMinedShape to the padded result.
auto reshapeOp = tensor::CollapseShapeOp::create(
@@ -486,10 +473,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
- LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
- << "maps: " << llvm::interleaved(indexingMaps) << "\n"
- << "iterators: " << llvm::interleaved(iteratorTypes)
- << "\n");
+ LDBG() << "Start packing: " << linalgOp;
+ LDBG() << "maps: " << llvm::interleaved(indexingMaps);
+ LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
SmallVector<linalg::PackOp> packOps;
SmallVector<linalg::UnPackOp> unPackOps;
@@ -511,14 +497,11 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
- LLVM_DEBUG(
- DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
- << "\n"
- << "maps: " << llvm::interleaved(indexingMaps) << "\n"
- << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
- << "packedDimForEachOperand: "
- << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
- << "\n");
+ LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
+ LDBG() << "maps: " << llvm::interleaved(indexingMaps);
+ LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
+ LDBG() << "packedDimForEachOperand: "
+ << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
}
// Step 2. Propagate packing to all LinalgOp operands.
@@ -534,10 +517,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
SmallVector<OpFoldResult> innerPackSizes =
listOfPackedOperandsDim.extractPackSizesForOperand(pos);
- LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
- << "innerPos: " << llvm::interleaved(innerPos) << "\n"
- << "innerPackSizes: "
- << llvm::interleaved(innerPackSizes) << "\n");
+ LDBG() << "operand: " << operand;
+ LDBG() << "innerPos: " << llvm::interleaved(innerPos);
+ LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
if (innerPackSizes.empty()) {
inputsAndInits.push_back(operand);
continue;
@@ -776,8 +758,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
- LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
- << numLoops << "\nin: " << linalgOp << "\n");
+ LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
+ << " in: " << linalgOp;
return rewriter.notifyMatchFailure(
linalgOp, "need 3+ loops to find a matmul to pack");
}
@@ -801,8 +783,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
FailureOr<ContractionDimensions> maybeDimensions =
inferContractionDims(linalgOp);
if (failed(maybeDimensions)) {
- LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
- << "\n");
+ LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer matmul iterators");
}
@@ -814,10 +795,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// to plug a heuristic.
int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
kPos = maybeDimensions->k.back();
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "Start packing generic op greedily with (m@" << mPos
- << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
- << "\n";);
+ LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
+ << nPos << ", k@" << kPos << "): " << linalgOp;
// 2.a. Rewrite as a generic.
auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
@@ -833,14 +812,14 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// not change the indexings of any operand.
SmallVector<int64_t> permutation =
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
- LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
+ LDBG() << "perm: " << llvm::interleaved(permutation);
// Sign .. unsigned pollution.
SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
FailureOr<GenericOp> interchangeResult =
interchangeGenericOp(rewriter, genericOp, unsignedPerm);
assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
genericOp = *interchangeResult;
- LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
+ LDBG() << "Generalized Op to pack: " << genericOp;
// At this point, the op iterators are normalized to {leading, k, m, n}.
// The layouts induced by packing will always be:
@@ -862,12 +841,11 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// Add leading zeros to match numLoops, we only pack the last 3 dimensions
// post interchange.
- LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
- << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
- << "loopRanges: "
- << llvm::interleaved(llvm::map_range(
- loopRanges, [](Range r) { return r.size; }))
- << "\n");
+ LDBG() << "paddedSizesNextMultipleOf: "
+ << llvm::interleaved(paddedSizesNextMultipleOf);
+ LDBG() << "loopRanges: "
+ << llvm::interleaved(
+ llvm::map_range(loopRanges, [](Range r) { return r.size; }));
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
rewriter.getIndexAttr(0));
for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
@@ -883,8 +861,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
{loopRanges[adjustedPackedSizes.size()].size,
rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
}
- LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
- << llvm::interleaved(adjustedPackedSizes) << "\n");
+ LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time. One would still need to check that
@@ -1214,9 +1191,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
- LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
- << "perm: " << llvm::interleaved(srcPermForTranspose)
- << "\n");
+ LDBG() << "Pack permutation: " << packOp;
+ LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a2a4335..2650488 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::MatmulTransposeAOp::create(
+ newMatmulOp = MatmulTransposeAOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
- newMatmulOp = linalg::MatmulTransposeBOp::create(
+ newMatmulOp = MatmulTransposeBOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
@@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
+ newMatmulOp = BatchMatmulTransposeAOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
- newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
+ newMatmulOp = BatchMatmulTransposeBOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cf65e67..406f05c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2563,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization";
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG()
<< "Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported";
@@ -2604,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) ||
hasReductionIterator(linalgOp));
}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index e1c0c24..d37a056 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
- ExpandPatterns.cpp
+ ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4a40a30..cd68039 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -13,14 +13,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXPANDOPSPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
/// Create a float constant.
static Value createFloatConst(Location loc, Type type, APFloat value,
OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
return success();
}
-void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
- patterns.add(convertCtlzOp);
-}
-
-void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertSinhOp);
-}
-
-void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertCoshOp);
-}
-
-void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanOp);
-}
-
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanhOp);
-}
-
-void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAsinhOp);
-}
-
-void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertAcoshOp);
-}
-
-void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAtanhOp);
-}
-
-void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
- patterns.add(convertFmaFOp);
-}
-
-void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
- patterns.add(convertCeilOp);
-}
-
-void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
- patterns.add(convertExp2fOp);
-}
-
-void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
- patterns.add(convertPowfOp);
-}
-
-void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowIOp);
-}
-
-void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundOp);
+// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
+static LogicalResult convertClampfOp(math::ClampFOp op,
+ PatternRewriter &rewriter) {
+ auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
+ op.getMin(), op.getFastmath());
+ rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
+ op.getFastmath());
+ return success();
}
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundEvenOp);
+void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
+ ArrayRef<StringRef> opMnemonics) {
+ auto filter = [&](StringRef name) {
+    // This should be a static_assert, and `consume_front` should take a
+    // Twine, but neither is currently possible. TODO: augment
+    // `StringRef::consume_front` and make `getDialectNamespace` use
+    // `std::string_view`.
+ assert("math" == MathDialect::getDialectNamespace());
+ name.consume_front("math.");
+ return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
+ };
+ if (filter(CountLeadingZerosOp::getOperationName()))
+ patterns.add(convertCtlzOp);
+ if (filter(SinhOp::getOperationName()))
+ patterns.add(convertSinhOp);
+ if (filter(CoshOp::getOperationName()))
+ patterns.add(convertCoshOp);
+ if (filter(TanOp::getOperationName()))
+ patterns.add(convertTanOp);
+ if (filter(TanhOp::getOperationName()))
+ patterns.add(convertTanhOp);
+ if (filter(AsinhOp::getOperationName()))
+ patterns.add(convertAsinhOp);
+ if (filter(AcoshOp::getOperationName()))
+ patterns.add(convertAcoshOp);
+ if (filter(AtanhOp::getOperationName()))
+ patterns.add(convertAtanhOp);
+ if (filter(FmaOp::getOperationName()))
+ patterns.add(convertFmaFOp);
+ if (filter(CeilOp::getOperationName()))
+ patterns.add(convertCeilOp);
+ if (filter(Exp2Op::getOperationName()))
+ patterns.add(convertExp2fOp);
+ if (filter(PowFOp::getOperationName()))
+ patterns.add(convertPowfOp);
+ if (filter(FPowIOp::getOperationName()))
+ patterns.add(convertFPowIOp);
+ if (filter(RoundOp::getOperationName()))
+ patterns.add(convertRoundOp);
+ if (filter(RoundEvenOp::getOperationName()))
+ patterns.add(convertRoundEvenOp);
+ if (filter(RsqrtOp::getOperationName()))
+ patterns.add(convertRsqrtOp);
+ if (filter(ClampFOp::getOperationName()))
+ patterns.add(convertClampfOp);
}
-void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
- patterns.add(convertRsqrtOp);
-}
+//===----------------------------------------------------------------------===//
+// MathExpandOpsPass pass
+//===----------------------------------------------------------------------===//
+namespace {
+struct MathExpandOpsPass final
+ : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
+ using MathExpandOpsPassBase::MathExpandOpsPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ SmallVector<StringRef> mnemonics =
+ llvm::to_vector_of<StringRef>(opMnemonics);
+ math::populateExpansionPatterns(patterns, mnemonics);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
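The consolidated populateExpansionPatterns keys its filter on op mnemonics with the "math." prefix stripped, and an empty list enables every expansion. A hypothetical caller that only expands powf and clampf (the wrapper function is an assumption, and the header path is assumed to be where the existing Math transform declarations live):

#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Expand only math.powf and math.clampf; other math ops are left untouched.
static mlir::LogicalResult expandSelectedMathOps(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::math::populateExpansionPatterns(patterns, {"powf", "clampf"});
  return mlir::applyPatternsGreedily(op, std::move(patterns));
}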
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c..b59d73d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
index bbb269b..1939195 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -21,9 +21,9 @@ namespace {
struct ReallocOpInterface
: public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface,
ReallocOp> {
- void
- populateDependencies(Operation *op,
- RegisterDependenciesFn registerDependenciesFn) const {
+ void populateDependencies(
+ Operation *op,
+ const RegisterDependenciesFn &registerDependenciesFn) const {
auto reallocOp = cast<ReallocOp>(op);
// memref.realloc may return the source operand.
registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 9771bd2..d35566a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
- if (!viewLikeOp)
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 5d3cec4..860384f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) {
/// propagate the type change and erase old subview ops.
static void replaceUsesAndPropagateType(RewriterBase &rewriter,
Operation *oldOp, Value val) {
- SmallVector<Operation *> opsToDelete;
- SmallVector<OpOperand *> operandsToReplace;
-
- // Save the operand to replace / delete later (avoid iterator invalidation).
- // TODO: can we use an early_inc iterator?
- for (OpOperand &use : oldOp->getUses()) {
- // Non-subview ops will be replaced by `val`.
- auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
- if (!subviewUse) {
- operandsToReplace.push_back(&use);
+ // Iterate with early_inc to erase current user inside the loop.
+ for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
+ Operation *user = use.getOwner();
+ if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
+ // `subview(old_op)` is replaced by a new `subview(val)`.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(subviewUse);
+ MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
+ subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
+ subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+ subviewUse.getStaticStrides());
+ Value newSubview = memref::SubViewOp::create(
+ rewriter, subviewUse->getLoc(), newType, val,
+ subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+ subviewUse.getMixedStrides());
+
+ // Ouch recursion ... is this really necessary?
+ replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
+
+ // Safe to erase.
+ rewriter.eraseOp(subviewUse);
continue;
}
-
- // `subview(old_op)` is replaced by a new `subview(val)`.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(subviewUse);
- MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
- subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
- subviewUse.getStaticStrides());
- Value newSubview = memref::SubViewOp::create(
- rewriter, subviewUse->getLoc(), newType, val,
- subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
- subviewUse.getMixedStrides());
-
- // Ouch recursion ... is this really necessary?
- replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
- opsToDelete.push_back(use.getOwner());
+ // Non-subview: replace with new value.
+ rewriter.startOpModification(user);
+ use.set(val);
+ rewriter.finalizeOpModification(user);
}
-
- // Perform late replacement.
- // TODO: can we use an early_inc iterator?
- for (OpOperand *operand : operandsToReplace) {
- Operation *op = operand->getOwner();
- rewriter.startOpModification(op);
- operand->set(val);
- rewriter.finalizeOpModification(op);
- }
-
- // Perform late op erasure.
- // TODO: can we use an early_inc iterator?
- for (Operation *op : opsToDelete)
- rewriter.eraseOp(op);
}
// Transformation to do multi-buffering/array expansion to remove dependencies
@@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
- // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
- // handle dealloc uses separately..
+  // 5. Due to the recursive nature of replaceUsesAndPropagateType, we need
+  //    to handle dealloc uses separately.
for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
if (!deallocOp)
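The rewrite above drops the deferred-deletion vectors by iterating the use list with llvm::make_early_inc_range. As a quick illustration of why that is safe, here is a minimal standalone sketch of the same idiom on a plain std::set (example values only, unrelated to the patch): the wrapped iterator is advanced before the element is handed to the loop body, so erasing the current element does not invalidate the traversal.

#include "llvm/ADT/STLExtras.h"
#include <cstdio>
#include <set>

int main() {
  std::set<int> values = {1, 2, 3, 4, 5, 6};
  for (int v : llvm::make_early_inc_range(values)) {
    if (v % 2 == 0)
      values.erase(v); // Safe: the underlying iterator already moved past `v`.
  }
  for (int v : values)
    std::printf("%d ", v); // Prints: 1 3 5
  return 0;
}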
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 5af46a4..3de9c38 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
MemrefValue skipViewLikeOps(MemrefValue source) {
while (auto op = source.getDefiningOp()) {
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
- source = cast<MemrefValue>(viewLike.getViewSource());
- continue;
+ if (source == viewLike.getViewDest()) {
+ source = cast<MemrefValue>(viewLike.getViewSource());
+ continue;
+ }
}
return source;
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 34c95e3..8474244 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -422,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
<< descMemref << " != " << dstMemref;
}
+ int lastDimBytes =
+ descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
+ if (lastDimBytes % 16 != 0) {
+    return op->emitError() << "the size in bytes of the last dimension of the "
+                              "tensor map must be a multiple of 16";
+ }
return std::nullopt;
}
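The added verifier rejects TMA descriptors whose innermost dimension does not span a whole number of 16-byte chunks. A standalone sketch of the same arithmetic, with hypothetical shapes and element widths chosen only for illustration:

#include <cstdio>

// Mirrors the added check: last dimension, measured in bytes, must be a
// multiple of 16.
static bool lastDimIsTmaAligned(int lastDimElems, int elementBitWidth) {
  int lastDimBytes = lastDimElems * elementBitWidth / 8;
  return lastDimBytes % 16 == 0;
}

int main() {
  std::printf("%d\n", lastDimIsTmaAligned(128, 16)); // 128 x f16 = 256 B -> 1 (accepted)
  std::printf("%d\n", lastDimIsTmaAligned(7, 16));   // 7 x f16   = 14 B  -> 0 (rejected)
}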
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 485bb73..ded4c7a 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -173,9 +173,7 @@ void OpenACCDialect::initialize() {
//===----------------------------------------------------------------------===//
static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
- if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
- return true;
- return false;
+ return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
}
static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
@@ -1390,6 +1388,36 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::ParallelOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
+void acc::ParallelOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
static ParseResult parseNumGangs(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -2041,6 +2069,36 @@ void acc::SerialOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::SerialOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
+void acc::SerialOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// KernelsOp
//===----------------------------------------------------------------------===//
@@ -3059,6 +3117,20 @@ void acc::LoopOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
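The new add* helpers on ParallelOp, SerialOp, and LoopOp all share the same append-to-ArrayAttr shape. A hedged sketch of that shared logic as a free-standing helper (the name appendRecipe and its placement are hypothetical; the patch keeps the logic inline in each method):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <iterator>

// Copy the existing recipe symbols (if any), append the new recipe's symbol,
// and rebuild the ArrayAttr that the op then stores back.
static mlir::ArrayAttr appendRecipe(mlir::MLIRContext *context,
                                    mlir::ArrayAttr existing,
                                    llvm::StringRef recipeSymName) {
  llvm::SmallVector<mlir::Attribute> recipes;
  if (existing)
    llvm::copy(existing, std::back_inserter(recipes));
  recipes.push_back(mlir::SymbolRefAttr::get(context, recipeSymName));
  return mlir::ArrayAttr::get(context, recipes);
}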
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767..6e43f28 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3874,6 +3874,159 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
+ return getInTypeAttr().getValue();
+}
+
+/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
+/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+/// attr-dict-without-keyword
+static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ auto &builder = parser.getBuilder();
+ bool hasOperands = false;
+ std::int32_t typeparamsSize = 0;
+
+ // Parse device number as a new operand
+ mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+ mlir::Type deviceType;
+ if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+ return mlir::failure();
+ if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+ return mlir::failure();
+ if (parser.parseComma())
+ return mlir::failure();
+
+ mlir::Type intype;
+ if (parser.parseType(intype))
+ return mlir::failure();
+ result.addAttribute("in_type", mlir::TypeAttr::get(intype));
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::Type> typeVec;
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeVec) || parser.parseRParen())
+ return mlir::failure();
+ typeparamsSize = operands.size();
+ hasOperands = true;
+ }
+ std::int32_t shapeSize = 0;
+ if (!parser.parseOptionalComma()) {
+ // parse size to scale by, vector of n dimensions of type index
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
+ return mlir::failure();
+ shapeSize = operands.size() - typeparamsSize;
+ auto idxTy = builder.getIndexType();
+ for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
+ typeVec.push_back(idxTy);
+ hasOperands = true;
+ }
+ if (hasOperands &&
+ parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
+ result.operands))
+ return mlir::failure();
+
+ mlir::Type restype = builder.getIntegerType(64);
+ if (!restype) {
+ parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
+ return mlir::failure();
+ }
+ llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
+ result.addAttribute("operandSegmentSizes",
+ builder.getDenseI32ArrayAttr(segmentSizes));
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.addTypeToList(restype, result.types))
+ return mlir::failure();
+ return mlir::success();
+}
+
+mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseTargetAllocMemOp(parser, result);
+}
+
+void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+ p << " ";
+ p.printOperand(getDevice());
+ p << " : ";
+ p << getDevice().getType();
+ p << ", ";
+ p << getInType();
+ if (!getTypeparams().empty()) {
+ p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
+ }
+ for (auto sh : getShape()) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"in_type", "operandSegmentSizes"});
+}
+
+llvm::LogicalResult omp::TargetAllocMemOp::verify() {
+ mlir::Type outType = getType();
+  if (!mlir::isa<IntegerType>(outType))
+    return emitOpError("must be an integer type");
+ return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+ // Check that region exists and is not empty
+ Region &region = getRegion();
+ if (region.empty())
+ return emitOpError("region cannot be empty");
+ // Verify single entry point.
+ Block &entryBlock = region.front();
+ if (entryBlock.empty())
+ return emitOpError("region must contain a structured block");
+ // Verify single exit point.
+ bool hasTerminator = false;
+ for (Block &block : region) {
+ if (isa<TerminatorOp>(block.back())) {
+ if (hasTerminator) {
+ return emitOpError("region must have exactly one terminator");
+ }
+ hasTerminator = true;
+ }
+ }
+ if (!hasTerminator) {
+ return emitOpError("region must be terminated with omp.terminator");
+ }
+ auto walkResult = region.walk([&](Operation *op) -> WalkResult {
+ // No implicit barrier at end
+ if (isa<BarrierOp>(op)) {
+ return emitOpError(
+ "explicit barriers are not allowed in workdistribute region");
+ }
+ // Check for invalid nested constructs
+ if (isa<ParallelOp>(op)) {
+ return emitOpError(
+ "nested parallel constructs not allowed in workdistribute");
+ }
+ if (isa<TeamsOp>(op)) {
+ return emitOpError(
+ "nested teams constructs not allowed in workdistribute");
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return failure();
+
+ Operation *parentOp = (*this)->getParentOp();
+  if (!llvm::isa_and_nonnull<TeamsOp>(parentOp))
+ return emitOpError("workdistribute must be nested under teams");
+ return success();
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index 497468b..bd1e655 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -1,3 +1,22 @@
+set(LLVM_OPTIONAL_SOURCES
+ MemorySpaceInterfaces.cpp
+ PtrAttrs.cpp
+ PtrTypes.cpp
+ PtrDialect.cpp
+)
+
+add_mlir_dialect_library(
+ MLIRPtrMemorySpaceInterfaces
+ MemorySpaceInterfaces.cpp
+
+ DEPENDS
+ MLIRPtrOpsEnumsGen
+ MLIRPtrMemorySpaceInterfacesIncGen
+ LINK_LIBS
+ PUBLIC
+ MLIRIR
+)
+
add_mlir_dialect_library(
MLIRPtrDialect
PtrAttrs.cpp
@@ -15,4 +34,5 @@ add_mlir_dialect_library(
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
MLIRViewLikeInterface
+ MLIRPtrMemorySpaceInterfaces
)
diff --git a/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp
new file mode 100644
index 0000000..059e67f
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp
@@ -0,0 +1,15 @@
+//===-- MemorySpaceInterfaces.cpp - ptr memory space interfaces -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ptr dialect memory space interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
index 772d25d..ac3bcd6 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
@@ -22,26 +22,30 @@ constexpr const static unsigned kBitsInByte = 8;
//===----------------------------------------------------------------------===//
bool GenericSpaceAttr::isValidLoad(
- Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidStore(
- Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidAtomicOp(
ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
- IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const {
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidAtomicXchg(
Type type, ptr::AtomicOrdering successOrdering,
- ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
+ ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c5ec0ca..d5976b9 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -85,6 +85,124 @@ LogicalResult FromPtrOp::verify() {
}
//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the attributes and the type of atomic memory access operations.
+template <typename OpTy>
+static LogicalResult
+verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
+ if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
+ if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
+ return memOp.emitOpError("unsupported ordering '")
+ << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
+ if (!memOp.getAlignment())
+ return memOp.emitOpError("expected alignment for atomic access");
+ return success();
+ }
+ if (memOp.getSyncscope()) {
+ return memOp.emitOpError(
+ "expected syncscope to be null for non-atomic access");
+ }
+ return success();
+}
+
+/// Verifies that the alignment attribute is a power of 2 if present.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!alignment)
+ return success();
+ if (alignment.value() <= 0)
+ return emitError() << "alignment must be positive";
+ if (!llvm::isPowerOf2_64(alignment.value()))
+ return emitError() << "alignment must be a power of 2";
+ return success();
+}
+
+void LoadOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+ // Volatile operations can have target-specific read-write effects on
+ // memory besides the one referred to by the pointer operand.
+  // Similarly, atomic operations that are monotonic or stricter cause
+  // synchronization that, from a language point of view, amounts to
+  // arbitrary reads and writes into memory.
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
+ getOrdering() != AtomicOrdering::unordered)) {
+ effects.emplace_back(MemoryEffects::Write::get());
+ effects.emplace_back(MemoryEffects::Read::get());
+ }
+}
+
+LogicalResult LoadOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(),
+ &dataLayout, emitDiag))
+ return failure();
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+ return verifyAtomicMemOp(*this,
+ {AtomicOrdering::release, AtomicOrdering::acq_rel});
+}
+
+void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
+ Value addr, unsigned alignment, bool isVolatile,
+ bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
+ AtomicOrdering ordering, StringRef syncscope) {
+ build(builder, state, type, addr,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
+ syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+void StoreOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+ // Volatile operations can have target-specific read-write effects on
+ // memory besides the one referred to by the pointer operand.
+  // Similarly, atomic operations that are monotonic or stricter cause
+  // synchronization that, from a language point of view, amounts to
+  // arbitrary reads and writes into memory.
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
+ getOrdering() != AtomicOrdering::unordered)) {
+ effects.emplace_back(MemoryEffects::Write::get());
+ effects.emplace_back(MemoryEffects::Read::get());
+ }
+}
+
+LogicalResult StoreOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(),
+ &dataLayout, emitDiag))
+ return failure();
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+ return verifyAtomicMemOp(*this,
+ {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
+}
+
+void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value addr, unsigned alignment, bool isVolatile,
+ bool isNonTemporal, bool isInvariantGroup,
+ AtomicOrdering ordering, StringRef syncscope) {
+ build(builder, state, value, addr,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ isVolatile, isNonTemporal, isInvariantGroup, ordering,
+ syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
+}
+
+//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -152,10 +270,6 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
-#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
-
-#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
-
#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
#define GET_TYPEDEF_CLASSES
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 825d119..deb7109 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRQuantTransforms
StripFuncQuantTypes.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
DEPENDS
MLIRQuantTransformsIncGen
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0262a1b..84f9777 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -157,8 +157,7 @@ void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
-
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
}
LogicalResult ExecuteRegionOp::verify() {
@@ -318,9 +317,12 @@ void ConditionOp::getSuccessorRegions(
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, ValueRange initArgs,
- BodyBuilderFn bodyBuilder) {
+ BodyBuilderFn bodyBuilder, bool unsignedCmp) {
OpBuilder::InsertionGuard guard(builder);
+ if (unsignedCmp)
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
result.addOperands({lb, ub, step});
result.addOperands(initArgs);
for (Value v : initArgs)
@@ -450,6 +452,9 @@ static void printInitializationList(OpAsmPrinter &p,
}
void ForOp::print(OpAsmPrinter &p) {
+ if (getUnsignedCmp())
+ p << " unsigned";
+
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();
@@ -462,7 +467,8 @@ void ForOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
}
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -472,6 +478,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument inductionVariable;
OpAsmParser::UnresolvedOperand lb, ub, step;
+ if (succeeded(parser.parseOptionalKeyword("unsigned")))
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
+
// Parse the induction variable followed by '='.
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
// Parse loop bounds.
@@ -562,7 +572,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
inits.append(newInitOperands.begin(), newInitOperands.end());
scf::ForOp newLoop = scf::ForOp::create(
rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
- [](OpBuilder &, Location, Value, ValueRange) {});
+ [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
// Generate the new yield values and append them to the scf.yield operation.
@@ -806,7 +816,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
// 2. Create the new forOp shell.
scf::ForOp newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterOperands);
+ forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
@@ -931,7 +942,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
scf::ForOp newForOp =
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
- forOp.getUpperBound(), forOp.getStep(), newIterArgs);
+ forOp.getUpperBound(), forOp.getStep(), newIterArgs,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
@@ -989,12 +1001,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference between two AffineValueMap is
/// dynamic.
-static std::optional<int64_t> computeConstDiff(Value l, Value u) {
+static std::optional<APInt> computeConstDiff(Value l, Value u) {
IntegerAttr clb, cub;
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
llvm::APInt lbValue = clb.getValue();
llvm::APInt ubValue = cub.getValue();
- return (ubValue - lbValue).getSExtValue();
+ return ubValue - lbValue;
}
// Else a simple pattern match for x + c or c + x
@@ -1003,7 +1015,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) {
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
matchPattern(
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff.getSExtValue();
+ return diff;
return std::nullopt;
}
@@ -1022,13 +1034,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}
- std::optional<int64_t> diff =
+ std::optional<APInt> diff =
computeConstDiff(op.getLowerBound(), op.getUpperBound());
if (!diff)
return failure();
// If the loop is known to have 0 iterations, remove it.
- if (*diff <= 0) {
+ bool zeroOrLessIterations =
+ diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
+ if (zeroOrLessIterations) {
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
@@ -3384,9 +3398,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
@@ -4222,14 +4235,15 @@ LogicalResult scf::IndexSwitchOp::verify() {
<< "see yield operation here";
}
for (auto [idx, result, operand] :
- llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
- yield.getOperandTypes())) {
- if (result == operand)
+ llvm::enumerate(getResultTypes(), yield.getOperands())) {
+ if (!operand)
+ return yield.emitOpError() << "operand " << idx << " is null\n";
+ if (result == operand.getType())
continue;
return (emitOpError("expected result #")
<< idx << " of each region to be " << result)
.attachNote(yield.getLoc())
- << name << " returns " << operand << " here";
+ << name << " returns " << operand.getType() << " here";
}
return success();
};
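The SimplifyTrivialLoops change above keeps the bound difference as an APInt so that the loop's signedness decides whether a "negative" difference means zero iterations. A standalone sketch with hypothetical 8-bit bounds shows why the distinction matters:

#include "llvm/ADT/APInt.h"
#include <cstdio>

int main() {
  // Bounds of a hypothetical 8-bit loop: lb = 10, ub = 250.
  llvm::APInt lb(8, 10), ub(8, 250);
  llvm::APInt diff = ub - lb; // 240, which reads as -16 when interpreted as signed i8.
  bool signedEmpty = diff.isZero() || diff.isNegative(); // true: a signed loop never runs.
  bool unsignedEmpty = diff.isZero();                    // false: an unsigned loop does run.
  std::printf("signed empty: %d, unsigned empty: %d\n", signedEmpty, unsignedEmpty);
}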
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index f8799c5..fb179e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -769,7 +769,8 @@ struct ForOpInterface
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), castedInitArgs);
+ forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block *loopBody = newForOp.getBody();
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index bee7780..ae52af5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
- auto cmpOp = arith::CmpIOp::create(
- rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
- beforeBlock->getArgument(0), forOp.getUpperBound());
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
+ beforeBlock->getArgument(0),
+ forOp.getUpperBound());
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 1130538..7e7fba4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
bool *modifiedIR) {
if (modifiedIR)
*modifiedIR = false;
+
+ // TODO: Add support for unsigned loops.
+ if (forOp.getUnsignedCmp())
+ return failure();
+
LoopPipelinerInternal pipeliner;
if (!pipeliner.initializeLoopInfo(forOp, options))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 4752c08..f1203b2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override {
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Do not peel already peeled loops.
if (forOp->hasAttr(kPeeledLoopLabel))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 694cd85..4ea8321 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -269,10 +269,10 @@ namespace {
struct ParallelLoopFusion
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
void runOnOperation() override {
- auto &AA = getAnalysis<AliasAnalysis>();
+ auto &aa = getAnalysis<AliasAnalysis>();
auto mayAlias = [&](Value val1, Value val2) -> bool {
- return !AA.alias(val1, val2).isNo();
+ return !aa.alias(val1, val2).isNo();
};
getOperation()->walk([&](Operation *child) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 1b07b77..072bc50 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ public:
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
@@ -116,7 +116,8 @@ public:
llvm::getSingleElement(adaptor.getLowerBound()),
llvm::getSingleElement(adaptor.getUpperBound()),
llvm::getSingleElement(adaptor.getStep()),
- flattenValues(adaptor.getInitArgs()));
+ flattenValues(adaptor.getInitArgs()),
+ /*bodyBuilder=*/nullptr, op.getUnsignedCmp());
// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
@@ -126,7 +127,6 @@ public:
// Inline the type converted region from the original operation.
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
-
return newOp;
}
};
@@ -225,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(
void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
- target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes());
- });
+ target.addDynamicallyLegalOp<ForOp, IfOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
return true;
- return typeConverter.isLegal(op.getOperandTypes());
+ return typeConverter.isLegal(op.getOperands());
});
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c0e47ee..834c021 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
inits.append(newInitOperands.begin(), newInitOperands.end());
auto newLoop = scf::ForOp::create(
rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
- loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loopOp.getUnsignedCmp());
// Move the loop body to the new op.
Block *loopBody = loopOp.getBody();
@@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest(
auto newLoop = scf::ForOp::create(
rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
forLoop.getUpperBound(), forLoop.getStep(), newInits,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
+ forLoop.getUnsignedCmp());
// Merge the body of the new loop with the body of the old loops.
SmallVector<Value> sourceBlockArgs;
@@ -1914,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}
-/// Check that the loop is perfectly nested.
-/// The loops are expected to be ordered from outer most to inner most.
-/// For example:
-/// ```
-/// %0 = scf.for()
-/// %1 = scf.for()
-/// %2 = scf.for()
-/// %3 = ...
-/// yield %3
-/// yield %2
-/// yield %1
-/// ```
-/// Here loops should be [%0, %1].
-static bool
-isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
- assert(!loops.empty() && "unexpected empty loop nest");
- if (loops.size() == 1) {
- return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
- }
- for (auto [outerLoop, innerLoop] :
- llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
- auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
- auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
- if (!outerFor || !innerFor) {
- return false;
- }
- auto outerBBArgs = outerFor.getRegionIterArgs();
- auto innerIterArgs = innerFor.getInitArgs();
- if (outerBBArgs.size() != innerIterArgs.size()) {
- return false;
- }
-
- for (auto [outerBBArg, innerIterArg] :
- llvm::zip_equal(outerBBArgs, innerIterArgs)) {
- if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
- innerIterArg != outerBBArg) {
- return false;
- }
- }
-
- ValueRange outerYields =
- cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
- ValueRange innerResults = innerFor.getResults();
- if (outerYields.size() != innerResults.size()) {
- return false;
- }
- for (auto [outerYield, innerResult] :
- llvm::zip_equal(outerYields, innerResults)) {
- if (!llvm::hasSingleElement(innerResult.getUses()) ||
- outerYield != innerResult) {
- return false;
- }
- }
- }
- return true;
-}
-
/// Fetch the untiled consumer of the outermost scf.for's result which is
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
/// makes the following assumptions :
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5731795..684dff8 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl(
static Loops stripmineSink(scf::ForOp forOp, Value factor,
ArrayRef<scf::ForOp> targets) {
+ assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
auto originalStep = forOp.getStep();
auto iv = forOp.getInductionVar();
@@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
Loops innerLoops;
for (auto t : targets) {
+ assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
+
// Save information for splicing ops out of t when done
auto begin = t.getBody()->begin();
auto nOps = t.getBody()->getOperations().size();
@@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
+ assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
+ "incompatible signedness");
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
@@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
rewriter.setInsertionPointAfter(source);
scf::ForOp fusedLoop = scf::ForOp::create(
rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
+ source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
+ source.getUnsignedCmp());
// Map original induction variables and operands to those of the fused loop.
IRMapping mapping;
@@ -1506,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
rewriter.replaceOp(forallOp, normalizedForallOp);
return normalizedForallOp;
}
+
+bool mlir::isPerfectlyNestedForLoops(
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected empty loop nest");
+ if (loops.size() == 1)
+ return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+ if (!outerFor || !innerFor)
+ return false;
+ auto outerBBArgs = outerFor.getRegionIterArgs();
+ auto innerIterArgs = innerFor.getInitArgs();
+ if (outerBBArgs.size() != innerIterArgs.size())
+ return false;
+
+ for (auto [outerBBArg, innerIterArg] :
+ llvm::zip_equal(outerBBArgs, innerIterArgs)) {
+ if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+ innerIterArg != outerBBArg)
+ return false;
+ }
+
+ ValueRange outerYields =
+ cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+ ValueRange innerResults = innerFor.getResults();
+ if (outerYields.size() != innerResults.size())
+ return false;
+ for (auto [outerYield, innerResult] :
+ llvm::zip_equal(outerYields, innerResults)) {
+ if (!llvm::hasSingleElement(innerResult.getUses()) ||
+ outerYield != innerResult)
+ return false;
+ }
+ }
+ return true;
+}
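With isPerfectlyNestedForLoops moved here and made public, external transforms can query loop-band structure directly. A hypothetical caller sketch (the header path is an assumption; the function is defined in namespace mlir as shown above):

#include "mlir/Dialect/SCF/Utils/Utils.h" // assumption: declares the moved utility
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/ADT/SmallVector.h"

// Succeed only when the collected scf.for band (outermost to innermost) is
// perfectly nested, so a later fusion step can assume that structure.
static mlir::LogicalResult
checkPerfectBand(llvm::SmallVectorImpl<mlir::LoopLikeOpInterface> &loops) {
  if (loops.empty())
    return mlir::failure();
  return mlir::success(mlir::isPerfectlyNestedForLoops(loops));
}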
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index ddb3426..369b953 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1322,7 +1322,7 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
}
TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
- : shape(std::move(shape)), elementType(std::move(elementType)) {}
+ : shape(shape), elementType(elementType) {}
ArrayRef<int64_t> shape;
Type elementType;
diff --git a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
index d4e7618..7a05dfe 100644
--- a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
@@ -513,8 +513,9 @@ LogicalResult shard::detail::defaultAddShardingAnnotations(
}
#ifndef NDEBUG
-static bool isValueCompatibleWithFullReplicationSharding(Value value,
- Sharding sharding) {
+static bool
+isValueCompatibleWithFullReplicationSharding(Value value,
+ const Sharding &sharding) {
if (isa<RankedTensorType>(value.getType())) {
return isFullReplication(sharding);
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 3e3d476..5dc61a2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -477,10 +477,10 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
return targetShard;
}
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
- Sharding sourceSharding, Sharding targetSharding,
- TypedValue<ShapedType> sourceUnshardedValue,
- TypedValue<ShapedType> sourceShard) {
+static TypedValue<ShapedType>
+reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding,
+ Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
@@ -535,7 +535,7 @@ using UnshardedToShardedValueMap = DenseMap<Value, Value>;
// Get the types of block arguments for an partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
-SmallVector<Type>
+static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
SymbolTableCollection &symbolTableCollection) {
SmallVector<Type> res;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 56b435c..9694a40 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -231,7 +231,9 @@ ParseResult DimLvlMapParser::parseLvlSpecList() {
const auto loc = parser.getCurrentLocation();
const auto res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::Paren,
- [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
+ [this, requireLvlVarBinding]() -> ParseResult {
+ return parseLvlSpec(requireLvlVarBinding);
+ },
" in level-specifier list");
FAILURE_IF_FAILED(res)
const auto specLvlRank = lvlSpecs.size();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 9e2e6ab..a1711a6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -156,13 +156,14 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
return pair1 <= pair2 ? sm1 : sm2;
}
-bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
+static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id,
+ StringRef name) {
const auto &var = env.access(id);
return (var.getName() == name && var.getID() == id);
}
-bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
- VarKind vk) {
+static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id,
+ llvm::SMLoc loc, VarKind vk) {
const auto &var = env.access(id);
return var.getKind() == vk;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 3b97786..dabbea1 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createLowerAffinePass());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
@@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createConvertComplexToLibm());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertComplexToLLVMPass());
- pm.addPass(
- createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertFuncToLLVMPass());
- pm.addPass(createArithToLLVMConversionPass());
- pm.addPass(createConvertControlFlowToLLVMPass());
// Finalize GPU code generation.
if (gpuCodegen) {
@@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
}
- // Convert poison values.
- pm.addPass(createUBToLLVMConversionPass());
+ // Convert to LLVM.
+ pm.addPass(createConvertToLLVMPass());
// Ensure all casts are realized.
pm.addPass(createReconcileUnrealizedCastsPass());
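The sparsifier pipeline change above replaces the chain of per-dialect *ToLLVM passes with the single ConvertToLLVMPass. A sketch of the consolidated tail of such a pipeline, under the assumption that the umbrella conversion header provides the create* helpers (this is not the full sparsifier pipeline):

#include "mlir/Conversion/Passes.h" // assumption: umbrella header for the create* helpers
#include "mlir/Pass/PassManager.h"

// Lower whatever remains to LLVM in one pass, then reconcile leftover casts.
static void buildMinimalLowering(mlir::OpPassManager &pm) {
  pm.addPass(mlir::createConvertToLLVMPass());
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
}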
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 3b4140e..ae7eef2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1219,8 +1219,9 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
-LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
- uint64_t ny, PatternRewriter &rewriter) {
+static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys,
+ AffineMap xPerm, uint64_t ny,
+ PatternRewriter &rewriter) {
Location loc = op.getLoc();
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3..0e88d31d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ public:
{tensor, lvlCoords, values, filled, added, count},
EmitCInterface::On);
Operation *parent = getTop(op);
+ rewriter.setInsertionPointAfter(parent);
rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
- rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 4464450..febec6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
VectorType vtp = vectorType(vl, init.getType());
Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
forOp.getRegionIterArg(0), init, vtp);
- forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
- forOp.getUpperBound(), step, vinit);
+ forOpNew =
+ scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
+ forOp.getUpperBound(), step, vinit,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
forOpNew->setAttr(
LoopEmitter::getLoopEmitterLoopAttrName(),
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
@@ -605,8 +607,8 @@ public:
ForOpRewriter(MLIRContext *context, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32)
- : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
- enableSIMDIndex32} {}
+ : OpRewritePattern(context),
+ vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d4b112..68584ec 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3200,20 +3200,6 @@ void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "padded");
}
-// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
-// supports optional types.
-void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
- Type typeToInfer, Type typeToInferFrom) {}
-
-ParseResult
-parseInferType(OpAsmParser &parser,
- std::optional<OpAsmParser::UnresolvedOperand> optOperand,
- Type &typeToInfer, Type typeToInferFrom) {
- if (optOperand)
- typeToInfer = typeToInferFrom;
- return success();
-}
-
LogicalResult PadOp::verify() {
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
@@ -4059,7 +4045,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
-bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
// 1. InsertSliceOp has its own logic about folding tensor.cast ops.
// 2. Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1..dfce835 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
PatternRewriter &rewriter) const override {
auto expandShapeOp =
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "tensor.extract_slice source not produced by expand_shape");
+ }
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
- if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
- rewriter)
- .failed())
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
+ offsets, sizes, strides)))
return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
- // referring to the state before applying the pattern are named with the
- // prefix "expanded", and ones referring to the state after applying the
- // pattern are named with the prefix "collapsed".
- SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> expandedShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- // Helper variables and function for accumulating the size values.
- Location loc = expandShapeOp->getLoc();
- AffineExpr d0, d1, d2;
- bindDims(rewriter.getContext(), d0, d1, d2);
- // Multiply two integers.
- auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
- auto mulMap = AffineMap::get(2, 0, {d0 * d1});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2});
- };
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank of
- // ReassociationIndices.size(). In the loop a single offset, size, and
- // stride value is computed per reassociation group.
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
- collapsedStrides;
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- // collapsedSize will hold the size of the single dim that represents the
- // reassociation group in the non expanded tensor.
- OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
- // The reassocGroupSizes and reassocGroupOffsets are used to create an
- // affine.linearize_index op to linearize the single offset value required
- // for this reassociation group.
- SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
- for (long expandedDim : indices) {
- // reassocGroupSizes and reassocGroupOffsets can be obtained directly
- // from the expanded state, but the collapsed size requires calculation
- // as it did not previously exist.
- reassocGroupSizes.push_back(expandedShape[expandedDim]);
- reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
- collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
- }
-
- SmallVector<Value> offsetVals =
- llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
- });
- OpFoldResult collapsedOffset =
- affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
- reassocGroupSizes,
- /*disjoint=*/true)
- .getResult();
- collapsedOffsets.push_back(collapsedOffset);
- collapsedSizes.push_back(collapsedSize);
-
- // Only unit stride is supported.
- collapsedStrides.push_back(rewriter.getIndexAttr(1));
- }
-
// The shape of the result can be obtained from the sizes passed in.
- SmallVector<Value> dynDims;
- SmallVector<int64_t> shape;
- dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
- RankedTensorType resultType = RankedTensorType::get(
- shape, expandShapeOp.getResultType().getElementType());
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ RankedTensorType resultType = sliceOp.getResultType();
// Create a new ExtractSliceOp and ExpandShapeOp.
+ Location loc = sliceOp.getLoc();
Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
- collapsedStrides);
+ rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
sliceOp, resultType, newSliceOp,
expandShapeOp.getReassociationIndices(), expandedSizes);
return success();
}
-
- // Helper function to check if all the required conditions for the
- // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
- // met.
- LogicalResult
- checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
- tensor::ExpandShapeOp expandShapeOp,
- PatternRewriter &rewriter) const {
-
- if (!expandShapeOp) {
- return rewriter.notifyMatchFailure(
- sliceOp, "tensor.extract_slice source not produced by expand_shape");
- }
-
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
-
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- sizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
-
- SmallVector<OpFoldResult> outputShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
- isZeroOffsetAndFullSize =
- [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isZeroInteger(offset))
- return false;
- FailureOr<bool> maybeEqual =
- ValueBoundsConstraintSet::areEqual(sliceSize, size);
- return llvm::succeeded(maybeEqual) && maybeEqual.value();
- };
-
- // Check that the slice is contiguous within each reassociation group.
- // The slice is contiguous only if after the first dimension where a non
- // unit slice is taken, the slice size on all subsequent dimensions of the
- // group is equal to the entire size of the dimension.
- // Examples of contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
- // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
- // Examples of non contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
- // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- int64_t i = 0;
- int64_t e = indices.size();
- // Find the first expanded dim after the first dim with non-unit extracted
- // size.
- for (; i < e; ++i) {
- if (!isOneInteger(sizes[indices[i]])) {
- // +1 to skip the first non-unit size dim.
- i++;
- break;
- }
- }
-
- // Verify that all subsequent dimensions extract the full size of the
- // source tensor.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
- outputShape[expandedDim])) {
- return rewriter.notifyMatchFailure(
- sliceOp, "Not a contiguous slice of the expanded tensor.");
- }
- }
- }
-
- return success();
- }
};
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,281 @@ struct BubbleUpCollapseShapeThroughExtractSlice
"tensor.extract_slice source not produced by tensor.collapse_shape");
}
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getExpandedExtractSliceInfo(
+ rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
+ collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
+ return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.collapse_shape, so variables (i.e. inputs for
- // ExtractSliceOp) referring to the state before applying the pattern are
- // named with the prefix "collapsed", and ones referring to the state after
- // applying the pattern are named with the prefix "expanded".
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- collapsedSizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
+ Value newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+ sizes, strides);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ sliceOp, sliceOp.getResultType(), newSliceOp,
+ collapseShapeOp.getReassociationIndices());
- ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
- SmallVector<ReassociationIndices, 4> reassociationIndices =
- collapseShapeOp.getReassociationIndices();
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank
- // equal to the rank of the src of the collapse_shape. In each iteration of
- // the loop, the offsets and sizes will be computed per reassociation group.
- SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
- SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
- rewriter.getIndexAttr(1));
-
- for (auto [collapsedSize, collapsedOffset, reassocIndices] :
- llvm::zip_equal(collapsedSizes, collapsedOffsets,
- collapseShapeOp.getReassociationIndices())) {
- // CASE #1 - size and/or offset are dynamic.
- // In this case, the slice can be represented as a contiguous slice only
- // if there is a single dimension in the reassociation group that has a
- // size not equal to 1.
- if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
- int nonUnitSizeCount = 0;
- for (int64_t expandedShapeIdx : reassocIndices) {
- if (srcShape[expandedShapeIdx] != 1) {
- nonUnitSizeCount++;
- expandedSizes.push_back(collapsedSize);
- expandedOffsets.push_back(collapsedOffset);
- continue;
- }
-
- expandedSizes.push_back(rewriter.getIndexAttr(1));
- expandedOffsets.push_back(rewriter.getIndexAttr(0));
- }
+ return success();
+ }
+};
- if (nonUnitSizeCount != 1) {
- return rewriter.notifyMatchFailure(
- sliceOp,
- "unsupported: slice cannot be verified to be contiguous");
- }
- continue;
- }
+} // namespace
- // CASE #2 = size and offset are static.
- // Verify that the slice can be represented as a contiguous slice of the
- // src of the collapse_shape.
- // Checking this is done on order of most internal dimensions first,
- // so traversal is done in reverse order of the reassociation group.
- // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
- // ...,An] then we first find the size and offset for n...k+1 then for k
- // and then for k-1...0.
-
- // currentCollapsedsize and currentCollapsedOffset are initialized with
- // the original collapsed size and offset and divided by the expanded
- // shape size in each dimension as we go along the reassociation group.
- // In essence we are spreading the original collapsed size and offset over
- // the various expanded slice dimensions.
- // The variables are used both to check the validity of the slice and to
- // compute the expanded sizes and offsets.
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
- int64_t currentCollapsedOffset =
- getConstantIntValue(collapsedOffset).value();
-
- SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
- ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
- reassocIndices.rend());
- int64_t idx = 0;
- int64_t reassocGroupSize = reassocIndices.size();
-
- // First handle the trailing dimensions where the slice size should be
- // equal to the tensor shape and the offset should be 0 (n...k+1).
- for (; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
- if (currentCollapsedsize < expandedShapeSize)
- break;
-
- // We need to make sure that the slice size can be set to the shape size
- // and the offset to 0.
- if ((currentCollapsedsize % expandedShapeSize) != 0 ||
- (currentCollapsedOffset % expandedShapeSize) != 0) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: cannot be extracted as a contiguous slice "
- "of the src of the collapse_shape");
- }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+ OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+ return failure();
+ }
- currentCollapsedsize /= expandedShapeSize;
- currentCollapsedOffset /= expandedShapeSize;
+ auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+ OpFoldResult sliceSize, int64_t inputDim) {
+ if (!isZeroInteger(offset))
+ return false;
+ ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // Check that the slice is contiguous within each reassociation group.
+ // The slice is contiguous only if after the first dimension where a non
+ // unit slice is taken, the slice size on all subsequent dimensions of the
+ // group is equal to the entire size of the dimension.
+ // Examples of contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+ // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+ // Examples of non contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+ // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
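+  // As a rough IR-level illustration (hypothetical example, not from a test):
+  //   %e = tensor.expand_shape %src [[0, 1]] output_shape [4, 8]
+  //       : tensor<32xf32> into tensor<4x8xf32>
+  //   %s = tensor.extract_slice %e[2, 0] [2, 8] [1, 1]
+  //       : tensor<4x8xf32> to tensor<2x8xf32>
+  // is contiguous; the group collapses to the single offset
+  // linearize_index(2, 0) over basis (4, 8) = 16 and the size 2 * 8 = 16,
+  // i.e. an extract_slice %src[16] [16] [1] on the 1-D source.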
+ for (const ReassociationIndices &indices : reassociation) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isOneInteger(sizes[indices[i]])) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
}
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ expandedDim)) {
+ return failure();
+ }
+ }
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+ // referring to the state before applying the pattern are named with the
+ // prefix "expanded", and ones referring to the state after applying the
+ // pattern are named with the prefix "collapsed".
+ Location loc = sliceOp.getLoc();
+ SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> expandedShape =
+ getMixedSizes(b, loc, sliceOp.getSource());
+
+ // Helper variables and function for accumulating the size values.
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
+ };
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+  // The new tensor.extract_slice will work on a tensor that has a rank of
+  // reassociation.size(). In the loop, a single offset, size, and stride
+  // value is computed per reassociation group.
+ for (const ReassociationIndices &indices : reassociation) {
+ // collapsedSize will hold the size of the single dim that represents the
+ // reassociation group in the non expanded tensor.
+ OpFoldResult collapsedSize = b.getIndexAttr(1);
+ // The reassocGroupSizes and reassocGroupOffsets are used to create an
+ // affine.linearize_index op to linearize the single offset value required
+ // for this reassociation group.
+ SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+ for (long expandedDim : indices) {
+ // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+ // from the expanded state, but the collapsed size requires calculation
+ // as it did not previously exist.
+ reassocGroupSizes.push_back(expandedShape[expandedDim]);
+ reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+ collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+ }
+
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ });
+ OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
+ b, loc, offsetVals, reassocGroupSizes,
+ /*disjoint=*/true)
+ .getResult();
+ collapsedOffsets.push_back(collapsedOffset);
+ collapsedSizes.push_back(collapsedSize);
+
+ // Only unit stride is supported.
+ collapsedStrides.push_back(b.getIndexAttr(1));
+ }
+ return success();
+}
+
+LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
+ OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
+ // ExtractSliceOp) referring to the state before applying the pattern are
+ // named with the prefix "collapsed", and ones referring to the state after
+ // applying the pattern are named with the prefix "expanded".
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+ collapsedSizes.size()) {
+ return failure();
+ }
- // Now handle the first dim where slicing occurs on (k).
- if (idx < reassocGroupSize) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- // We need to make sure that the slice size in this dim + offset will
- // not exceed the shape size.
- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: slice cannot be extracted as a contiguous "
- "slice of the src of the collapse_shape");
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+  // The new tensor.extract_slice will work on a tensor that has a rank
+ // equal to the rank of the src of the collapse_shape. In each iteration of
+ // the loop, the offsets and sizes will be computed per reassociation group.
+ expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+ llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
+ // CASE #1 - size and/or offset are dynamic.
+ // In this case, the slice can be represented as a contiguous slice only
+ // if there is a single dimension in the reassociation group that has a
+ // size not equal to 1.
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ int nonUnitSizeCount = 0;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (expandedShape[expandedShapeIdx] != 1) {
+ nonUnitSizeCount++;
+ expandedSizes.push_back(collapsedSize);
+ expandedOffsets.push_back(collapsedOffset);
+ continue;
}
- groupExpandedSizes.push_back(
- rewriter.getIndexAttr(currentCollapsedsize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+ expandedSizes.push_back(b.getIndexAttr(1));
+ expandedOffsets.push_back(b.getIndexAttr(0));
+ }
- currentCollapsedOffset /= expandedShapeSize;
+ if (nonUnitSizeCount != 1) {
+ return failure();
}
+ continue;
+ }
- // Now handle the leading dimensions where the slice size is equal to 1
- // (k-1...0).
- // The size for these dimensions must be 1 because of how we constructed
- // the slice size of the expanded shape. We spread the original collapsed
- // size over the expanded shape sizes until we reached dimension k where
- // the remaining size was smaller than the expanded shape size, and spread
- // the remaining size on it. So, now we are left with only 1s.
- for (idx++; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- currentCollapsedOffset /= expandedShapeSize;
+    // CASE #2 - size and offset are static.
+    // Verify that the slice can be represented as a contiguous slice of the
+    // src of the collapse_shape.
+    // Checking is done in order of the innermost dimensions first, so
+    // traversal is done in reverse order of the reassociation group.
+    // If the expected slice shape is [1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An]
+    // then we first find the size and offset for n...k+1, then for k, and
+    // then for k-1...0.
+
+ // currentCollapsedsize and currentCollapsedOffset are initialized with
+ // the original collapsed size and offset and divided by the expanded
+ // shape size in each dimension as we go along the reassociation group.
+ // In essence we are spreading the original collapsed size and offset over
+ // the various expanded slice dimensions.
+ // The variables are used both to check the validity of the slice and to
+ // compute the expanded sizes and offsets.
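+    // Worked example (hypothetical numbers): for a reassociation group with
+    // expanded sizes [4, 8] and a collapsed slice of size 16 at offset 8, the
+    // trailing dim takes size 8 / offset 0 (both 16 and 8 are divisible by
+    // 8), leaving size 2 and offset 1, so the leading dim takes size 2 at
+    // offset 1; the expanded slice has offsets [1, 0] and sizes [2, 8].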
+ int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+ int64_t currentCollapsedOffset =
+ getConstantIntValue(collapsedOffset).value();
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+ if (currentCollapsedsize < expandedShapeSize)
+ break;
+
+ // We need to make sure that the slice size can be set to the shape size
+ // and the offset to 0.
+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+ (currentCollapsedOffset % expandedShapeSize) != 0) {
+ return failure();
}
- expandedSizes.append(groupExpandedSizes.rbegin(),
- groupExpandedSizes.rend());
- expandedOffsets.append(groupExpandedOffsets.rbegin(),
- groupExpandedOffsets.rend());
+ groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(0));
+
+ currentCollapsedsize /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
- Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
- expandedOffsets, expandedSizes, expandedStrides);
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- sliceOp, sliceOp.getResultType(), newSliceOp,
- collapseShapeOp.getReassociationIndices());
+ // Now handle the first dim where slicing occurs on (k).
+ if (idx < reassocGroupSize) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ // We need to make sure that the slice size in this dim + offset will
+ // not exceed the shape size.
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
+ return failure();
+ }
+ groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
- return success();
+ // Now handle the leading dimensions where the slice size is equal to 1
+ // (k-1...0).
+ // The size for these dimensions must be 1 because of how we constructed
+ // the slice size of the expanded shape. We spread the original collapsed
+ // size over the expanded shape sizes until we reached dimension k where
+ // the remaining size was smaller than the expanded shape size, and spread
+ // the remaining size on it. So, now we are left with only 1s.
+ for (idx++; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ groupExpandedSizes.push_back(b.getIndexAttr(1));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
+ expandedSizes.append(groupExpandedSizes.rbegin(),
+ groupExpandedSizes.rend());
+ expandedOffsets.append(groupExpandedOffsets.rbegin(),
+ groupExpandedOffsets.rend());
}
-};
-
-} // namespace
+ return success();
+}
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba388..8d63646 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -122,8 +122,9 @@ struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
const APFloat lowestVal =
APFloat::getLargest(padConstVal.getSemantics(), true);
return padConstVal == lowestVal;
- } else if (auto padConstIntAttr =
- mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+ }
+ if (auto padConstIntAttr =
+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
const APInt padConstVal = *padConstIntAttr.begin();
const unsigned int bitWidth = padConstVal.getBitWidth();
const APInt lowestVal =
@@ -555,7 +556,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
- if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ if (opNanMode == NanPropagationMode::IGNORE &&
+ clampNanMode == NanPropagationMode::PROPAGATE)
return failure();
auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +638,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}
+ auto newMode = (opNanMode != clampNanMode)
+ ? tosa::NanPropagationMode::IGNORE
+ : opNanMode;
+
+ auto newModeAttr =
+ NanPropagationModeAttr::get(rewriter.getContext(), newMode);
+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
- rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
- : opNanMode));
+ newModeAttr);
return success();
}
};
@@ -1120,13 +1128,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
if (rhsTy == resultTy) {
- if (isSplatZero(resultETy, lhsAttr))
+ if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+      // Constant values can only be resized if the resulting type is static.
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
- if (isSplatZero(resultETy, rhsAttr))
+ if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3cafb19..bd7aee5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -270,6 +270,244 @@ void mlir::tosa::printVariableOpTypeOrInitialValue(
}
}
+namespace {
+
+// Parse an attribute entry, with special handling for TOSA enum attributes.
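+// For example (illustrative asm, not taken from a test), this accepts a bare
+// enum keyword such as
+//   rounding_mode = DOUBLE_ROUND
+// in the attr-dict instead of requiring the fully qualified attribute form.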
+template <typename EnumType>
+ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
+ NamedAttrList &outAttrs) {
+ llvm::StringRef name;
+ if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
+ return failure();
+
+ // special handling: rounding_mode accepts a *bare* RoundingMode enum
+ // keyword.
+ llvm::StringRef kw;
+ if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
+ if (name == "rounding_mode" &&
+ succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeRoundingMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid rounding_mode value: " << kw;
+ auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+ // special handling: mode accepts a *bare* ResizeMode enum keyword.
+ if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
+ if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeResizeMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid resize mode value: " << kw;
+ auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+ // special handling: nan_mode accepts a *bare* NanPropagationMode enum
+ // keyword.
+ if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
+ if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeNanPropagationMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid nan_mode value: " << kw;
+ auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
+  // Default path: parse any normal attribute literal, including fully
+  // qualified enum attribute forms.
+ Attribute attr;
+ return parser.parseAttribute(attr, name, outAttrs);
+}
+
+template <typename EnumType>
+ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
+ // parse operands
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> operands;
+ if (parser.parseCommaSeparatedList(
+ [&]() { return parser.parseOperand(operands.emplace_back()); }))
+ return failure();
+
+  // Parse { attr-dict } with special handling for bare enum keywords.
+ NamedAttrList attrs;
+ if (succeeded(parser.parseOptionalLBrace()) &&
+ failed(parser.parseOptionalRBrace())) {
+ do {
+ if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
+ return failure();
+ } while (succeeded(parser.parseOptionalComma()));
+ if (parser.parseRBrace())
+ return failure();
+ }
+
+ FunctionType fnTy;
+ if (parser.parseColonType(fnTy))
+ return failure();
+
+ // Resolve operands and types
+ if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+
+ result.addTypes(fnTy.getResult(0));
+ result.addAttributes(attrs);
+
+ return success();
+}
+
+void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
+ parser << namedAttr.getName().strref() << " = ";
+ auto attr = namedAttr.getValue();
+ if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
+ parser << roundingModeAttr.getValue();
+ } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
+ parser << resizeModeAttr.getValue();
+ } else if (auto nanPropagationModeAttr =
+ dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
+ parser << nanPropagationModeAttr.getValue();
+ } else {
+ parser.printAttribute(attr);
+ }
+}
+
+// Print with special handling for the default-valued NanPropagationMode
+// attribute.
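+// For example (illustrative, not from a test), a clamp whose nan_mode is the
+// default PROPAGATE prints without that entry, while a non-default value such
+// as nan_mode = IGNORE is kept in the attr-dict.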
+void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
+ parser << " ";
+ parser.printOperands(op->getOperands());
+
+ NamedAttrList toPrint(op->getAttrs());
+  // Elide the NanPropagationMode attribute when it has the default value.
+ const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
+ for (auto attr : op->getAttrs()) {
+ if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
+ if (nanAttr.getValue() == kDefaultNanValue) {
+ // elide from toPrint
+ toPrint.erase(attr.getName());
+ break;
+ }
+ }
+ }
+
+ if (!toPrint.empty()) {
+ parser << " {";
+ llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
+ printNamedAttr(parser, namedAttr);
+ });
+ parser << "}";
+ }
+
+ parser << " : ";
+ parser.printFunctionalType(op);
+}
+
+// Print with special handling for enums: RoundingMode, ResizeMode.
+void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
+ parser << " ";
+ parser.printOperands(op->getOperands());
+
+ if (!op->getAttrs().empty()) {
+ parser << " {";
+ llvm::interleaveComma(op->getAttrs(), parser,
+ [&](const NamedAttribute namedAttr) {
+ printNamedAttr(parser, namedAttr);
+ });
+ parser << "}";
+ }
+
+ parser << " : ";
+ parser.printFunctionalType(op);
+}
+
+} // namespace
+
+ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
+}
+
+void RescaleOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
+}
+
+void ApplyScaleOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
+}
+
+void ResizeOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ArgMaxOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ClampOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaximumOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MinimumOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ReduceMaxOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ReduceMinOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534..790bbf7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,14 +508,15 @@ private:
bool attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
- if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+ if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
return false;
- } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
- !targetEnv.allows(Extension::inexactround)) {
+ }
+ if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
+ !targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
@@ -1122,7 +1123,7 @@ bool checkErrorIfRescale(Operation *op) {
}
// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
- if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+ if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) {
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
return false;
}
@@ -1307,7 +1308,8 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
- } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+ }
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 9266a63..48df1a0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -37,16 +37,13 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
-
#define DEBUG_TYPE_MATCHER "transform-matcher"
-#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
-#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
using namespace mlir;
@@ -182,8 +179,7 @@ transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
if (result.isSilenceableFailure()) {
- LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
- << "\n");
+ LDBG() << "alternative failed: " << result.getMessage();
failed = true;
break;
}
@@ -1155,12 +1151,10 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
std::optional<DiagnosedSilenceableFailure> maybeFailure;
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
- DEBUG_MATCHER({
- DBGS_MATCHER() << "matching ";
- op->print(llvm::dbgs(),
- OpPrintingFlags().assumeVerified().skipRegions());
- llvm::dbgs() << " @" << op << "\n";
- });
+ LDBG(1, DEBUG_TYPE_MATCHER)
+ << "matching "
+ << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
+ << " @" << op;
// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
@@ -1172,8 +1166,8 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
- << " failed: " << diag.getMessage());
+ LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage();
return WalkResult::advance();
}
@@ -1304,12 +1298,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
if (!getRestrictRoot() && op == root)
return WalkResult::advance();
- DEBUG_MATCHER({
- DBGS_MATCHER() << "matching ";
- op->print(llvm::dbgs(),
- OpPrintingFlags().assumeVerified().skipRegions());
- llvm::dbgs() << " @" << op << "\n";
- });
+ LDBG(1, DEBUG_TYPE_MATCHER)
+ << "matching "
+ << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
+ << " @" << op;
firstMatchArgument.clear();
firstMatchArgument.push_back(op);
@@ -1322,8 +1314,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
- << " failed: " << diag.getMessage());
+ LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage();
continue;
}
@@ -2173,10 +2165,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
transform::TransformResults &results, transform::TransformState &state) {
if (!maybeCurrent.has_value()) {
- DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
+ LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp success";
return DiagnosedSilenceableFailure::success();
}
- DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
+ LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp failure";
return emitSilenceableError() << "operation is not empty";
}
diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index d666390..773eb13 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -90,7 +91,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
//
// Rename private symbols in both ops in order to resolve conflicts that can
// be resolved that way.
- LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ LDBG() << "renaming private symbols to resolve conflicts:";
// TODO: Do we *actually* need to test in both directions?
for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
@@ -102,7 +103,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
if (!symbolOp)
continue;
StringAttr name = symbolOp.getNameAttr();
- LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+ LDBG() << " found @" << name.getValue();
// Check if there is a colliding op in the other module.
auto collidingOp =
@@ -110,7 +111,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
if (!collidingOp)
continue;
- LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+ LDBG() << " collision found for @" << name.getValue();
    // Collisions are fine if both ops are functions and can be merged.
if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
@@ -119,13 +120,12 @@ transform::detail::mergeSymbolsInto(Operation *target,
funcOp && collidingFuncOp) {
if (canMergeInto(funcOp, collidingFuncOp) ||
canMergeInto(collidingFuncOp, funcOp)) {
- LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
- "will be merged\n");
+ LDBG() << " but both ops are functions and will be merged";
continue;
}
// If they can't be merged, proceed like any other collision.
- LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+ LDBG() << " and both ops are function definitions";
}
// Collision can be resolved by renaming if one of the ops is private.
@@ -133,7 +133,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
- LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+ LDBG() << ", renaming";
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
if (failed(maybeNewName)) {
@@ -142,8 +142,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "attempted renaming due to collision with this op";
return diag;
}
- LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
- << "\n");
+ LDBG() << " renamed to @" << maybeNewName->getValue();
return InFlightDiagnostic();
};
@@ -161,7 +160,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
return diag;
continue;
}
- LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ LDBG() << ", emitting error";
InFlightDiagnostic diag = symbolOp.emitError()
<< "doubly defined symbol @" << name.getValue();
diag.attachNote(collidingOp->getLoc()) << "previously defined here";
@@ -179,7 +178,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
// Step 2:
//
// Move all ops from `other` into target and merge public symbols.
- LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+ LDBG() << "moving all symbols into target";
{
SmallVector<SymbolOpInterface> opsToMove;
for (Operation &op : other->getRegion(0).front()) {
@@ -193,13 +192,13 @@ transform::detail::mergeSymbolsInto(Operation *target,
targetSymbolTable.lookup(op.getNameAttr()));
// Move op even if we get a collision.
- LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+ LDBG() << " moving @" << op.getName();
op->moveBefore(&target->getRegion(0).front(),
target->getRegion(0).front().end());
// If there is no collision, we are done.
if (!collidingOp) {
- LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+ LDBG() << " without collision";
continue;
}
@@ -217,9 +216,9 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
assert(canMergeInto(funcOp, collidingFuncOp));
- LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
- << collidingFuncOp.getLoc() << ":\n"
- << collidingFuncOp << "\n");
+ LDBG() << " with collision, trying to keep op at "
+ << collidingFuncOp.getLoc() << ":\n"
+ << collidingFuncOp;
// Update symbol table. This works with or without the previous `swap`.
targetSymbolTable.remove(funcOp);
@@ -239,6 +238,6 @@ transform::detail::mergeSymbolsInto(Operation *target,
return target->emitError()
<< "failed to verify target op after merging symbols";
- LLVM_DEBUG(DBGS() << "done merging ops\n");
+ LDBG() << "done merging ops";
return InFlightDiagnostic();
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 14a4fdf..4f4620a 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -312,7 +312,7 @@ LogicalResult transform::TransformState::setParams(Value value,
}
template <typename Mapping, typename Key, typename Mapped>
-void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
+static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
auto it = mapping.find(key);
if (it == mapping.end())
return;
@@ -771,7 +771,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
}
template <typename T>
-DiagnosedSilenceableFailure
+static DiagnosedSilenceableFailure
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
transform::TransformOpInterface transform,
unsigned operandNumber) {
diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
index 41955c8..3ced1a6 100644
--- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
@@ -100,12 +100,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
PatternApplicator applicator(it->second);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
- struct TrivialPatternRewriter : public PatternRewriter {
- public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
- };
- TrivialPatternRewriter rewriter(root->getContext());
+ PatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index e6ef028..34385d7 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
if (!ubConstant)
return std::nullopt;
std::optional<int64_t> stepConstant = getConstantIntValue(step);
- if (!stepConstant)
+ if (!stepConstant || *stepConstant == 0)
return std::nullopt;
return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cb4783d..9b2a455 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
return foldToElementsFromElements(*this, results);
}
+LogicalResult
+ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+ ToElementsOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto vecType = cast<VectorType>(adaptor.getSource().getType());
+ Type elType = vecType.getElementType();
+ inferredReturnTypes.append(vecType.getNumElements(), elType);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -2456,8 +2466,12 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
return {};
+ // DenseElementsAttr only supports int/index/float/complex types.
auto destVecType = fromElementsOp.getDest().getType();
auto destEltType = destVecType.getElementType();
+ if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
+ return {};
+
// Constant attributes might have a different type than the return type.
// Convert them before creating the dense elements attribute.
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
@@ -2768,8 +2782,8 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
- if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
- getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+ if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
+ srcType == getElementTypeOrSelf(dstVectorType))
return BroadcastableToResult::Success;
// From now on, only vectors broadcast.
VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
@@ -3276,6 +3290,18 @@ LogicalResult InsertOp::verify() {
return success();
}
+// Calculate the linearized position of the contiguous chunk of elements to
+// insert, based on the shape of the value to insert and the positions to
+// insert at.
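+// For example (hypothetical values), with destTy vector<2x3xf32> and
+// positions [1], the strides are [3, 1] and the linearized begin position is
+// 1 * 3 = 3.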
+static int64_t calculateInsertPosition(VectorType destTy,
+ ArrayRef<int64_t> positions) {
+ llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+ assert(positions.size() <= completePositions.size() &&
+ "positions size must be less than or equal to destTy rank");
+ copy(positions, completePositions.begin());
+ return linearize(completePositions, computeStrides(destTy.getShape()));
+}
+
namespace {
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3313,6 +3339,132 @@ public:
return success();
}
};
+
+/// Pattern to optimize a chain of insertions.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Only insert values at static positions.
+/// 2. Completely initialize all elements in the resulting vector.
+/// 3. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single vector.from_elements operation.
+///
+/// To keep this pattern simple, and avoid spending too much time on matching
+/// fragmented insert chains, this pattern only considers the last insert op in
+/// the chain.
+///
+/// Example transformation:
+/// %poison = ub.poison : vector<2xi32>
+/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
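+///
+/// Inserted values may themselves be vectors; their elements are recovered
+/// with vector.to_elements (illustrative example, not from a test):
+///   %poison = ub.poison : vector<2x3xi32>
+///   %0 = vector.insert %v0, %poison[0] : vector<3xi32> into vector<2x3xi32>
+///   %1 = vector.insert %v1, %0[1] : vector<3xi32> into vector<2x3xi32>
+/// ->
+///   %e:3 = vector.to_elements %v0 : vector<3xi32>
+///   %f:3 = vector.to_elements %v1 : vector<3xi32>
+///   %result = vector.from_elements %e#0, %e#1, %e#2, %f#0, %f#1, %f#2
+///             : vector<2x3xi32>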
+class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorType destTy = op.getDestVectorType();
+ if (destTy.isScalable())
+ return failure();
+ // Ensure this is the trailing vector.insert op in a chain of inserts.
+ for (Operation *user : op.getResult().getUsers())
+ if (auto insertOp = dyn_cast<InsertOp>(user))
+ if (insertOp.getDest() == op.getResult())
+ return failure();
+
+ InsertOp currentOp = op;
+ SmallVector<InsertOp> chainInsertOps;
+ while (currentOp) {
+ // Check cond 1: Dynamic position is not supported.
+ if (currentOp.hasDynamicPosition())
+ return failure();
+
+ chainInsertOps.push_back(currentOp);
+ currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
+ // Check cond 3: Intermediate inserts have only one use to avoid an
+ // explosion of vectors.
+ if (currentOp && !currentOp->hasOneUse())
+ return failure();
+ }
+
+ int64_t vectorSize = destTy.getNumElements();
+ int64_t initializedCount = 0;
+ SmallVector<bool> initializedDestIdxs(vectorSize, false);
+ SmallVector<int64_t> pendingInsertPos;
+ SmallVector<int64_t> pendingInsertSize;
+ SmallVector<Value> pendingInsertValues;
+
+ for (auto insertOp : chainInsertOps) {
+      // This pattern cannot handle poison indices.
+ if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+ return failure();
+
+ // Calculate the linearized position for inserting elements.
+ int64_t insertBeginPosition =
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+ // The valueToStore operand may be a vector or a scalar. Need to handle
+ // both cases.
+ int64_t insertSize = 1;
+ if (auto srcVectorType =
+ llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+ insertSize = srcVectorType.getNumElements();
+
+ assert(insertBeginPosition + insertSize <= vectorSize &&
+ "insert would overflow the vector");
+
+ for (auto index : llvm::seq<int64_t>(insertBeginPosition,
+ insertBeginPosition + insertSize)) {
+ if (initializedDestIdxs[index])
+ continue;
+ initializedDestIdxs[index] = true;
+ ++initializedCount;
+ }
+
+      // Defer creating any ops until we know the pattern will succeed.
+ pendingInsertPos.push_back(insertBeginPosition);
+ pendingInsertSize.push_back(insertSize);
+ pendingInsertValues.push_back(insertOp.getValueToStore());
+
+ if (initializedCount == vectorSize)
+ break;
+ }
+
+ // Check cond 2: all positions must be initialized.
+ if (initializedCount != vectorSize)
+ return failure();
+
+ SmallVector<Value> elements(vectorSize);
+ for (auto [insertBeginPosition, insertSize, valueToStore] :
+ llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
+ pendingInsertValues))) {
+ auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+
+ if (!srcVectorType) {
+ elements[insertBeginPosition] = valueToStore;
+ continue;
+ }
+
+ SmallVector<Type> elementToInsertTypes(insertSize,
+ srcVectorType.getElementType());
+ // Get all elements from the vector in row-major order.
+ auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
+ op.getLoc(), elementToInsertTypes, valueToStore);
+ for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
+ elements[insertBeginPosition + linearIdx] =
+ elementsToInsert.getResult(linearIdx);
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
+ return success();
+ }
+};
+
} // namespace
static Attribute
@@ -3339,13 +3491,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
!insertOp->hasOneUse())
return {};
- // Calculate the linearized position of the continuous chunk of elements to
- // insert.
- llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(insertOp.getStaticPosition(), completePositions.begin());
+ // Calculate the linearized position for inserting elements.
int64_t insertBeginPosition =
- linearize(completePositions, computeStrides(destTy.getShape()));
-
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();
@@ -3381,7 +3529,8 @@ static Value foldInsertUseChain(InsertOp insertOp) {
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+ InsertChainFullyInitialized>(context);
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
@@ -5637,7 +5786,7 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != baseType.getRank())
+ if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
@@ -5709,11 +5858,11 @@ public:
if (!isa<MemRefType>(op.getBase().getType()))
return rewriter.notifyMatchFailure(op, "base must be of memref type");
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
- op.getIndices(), op.getMask(),
+ op.getOffsets(), op.getMask(),
op.getPassThru());
return success();
}
@@ -5737,7 +5886,7 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
+ if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
@@ -5772,11 +5921,11 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
- op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2d5cc07..fe066dc 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
vector::populateVectorGatherLoweringPatterns(patterns);
}
+void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 6619619..546099c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
- gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
gatherOp.getPassThru());
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9e287fc..acbf2b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
+ LowerVectorFromElements.cpp
LowerVectorGather.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
new file mode 100644
index 0000000..c22fd54
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
@@ -0,0 +1,65 @@
+//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.from_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-from-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = ub.poison : vector<2x3xf32>
+/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange allElements = op.getElements();
+
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ size_t subTyNumElements = subTy.getNumElements();
+ assert((index + 1) * subTyNumElements <= allElements.size() &&
+ "out of bounds");
+ ValueRange subElements =
+ allElements.slice(index * subTyNumElements, subTyNumElements);
+ return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorFromElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
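
Illustrative note: outside the transform dialect, these patterns would typically be driven by the greedy rewriter. A minimal sketch, assuming the standard applyPatternsGreedily entry point is available in this tree:

    #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Exhaustively unroll n-D vector.from_elements ops under `root` to 1-D ops.
    static mlir::LogicalResult unrollAllFromElements(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      mlir::vector::populateVectorFromElementsLoweringPatterns(patterns);
      // Each application peels one leading dimension; the driver iterates to a
      // fixpoint, leaving only 1-D vector.from_elements ops.
      return mlir::applyPatternsGreedily(root, std::move(patterns));
    }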
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55..9830189 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already 1-D");
-
- // Unrolling doesn't take vscale into account. Pattern is disabled for
- // vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
- return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
- Location loc = op.getLoc();
- Value indexVec = op.getIndexVec();
+ Value indexVec = op.getIndices();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
- Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
- rewriter.getZeroAttr(resultTy));
-
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
+ auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
Value indexSubVec =
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
@@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
- Value subGather = vector::GatherOp::create(
- rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
- maskSubVec, passThruSubVec);
- result =
- vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
- }
+ return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
+ op.getOffsets(), indexSubVec, maskSubVec,
+ passThruSubVec);
+ };
- rewriter.replaceOp(op, result);
- return success();
+ return unrollVectorOp(op, rewriter, unrollGatherFn);
}
};
@@ -158,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
- VectorType vType = op.getIndexVec().getType();
+ VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
Value newIdxs =
- arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+ arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = vector::GatherOp::create(
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
- op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+ op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);
return success();
@@ -212,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
- op.getIndexVec());
- auto baseOffsets = llvm::to_vector(op.getIndices());
+ op.getIndices());
+ auto baseOffsets = llvm::to_vector(op.getOffsets());
Value lastBaseOffset = baseOffsets.back();
Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f0..5617b06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ public:
// Replace the `vector.mask` operation.
rewriter.replaceOpWithNewOp<GatherOp>(
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
- gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
passthru);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bb0f339..c84eb2c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
if (!writeOp)
@@ -706,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern {
}
};
+/// Sink out step op feeding into a warp op yield.
+/// Vector step op is treated similarly to arith.constant, except that its
+/// result represents the sequence [0, vec_size).
+/// Due to the vec_size == warp_size limitation,
+/// we can simply wrap the lane id into a vector (i.e., broadcast).
+/// Supporting vec_size != warp_size may involve preserving the step
+/// result and using additional arith ops (the exact details are TBD).
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+/// ...
+/// %cst = vector.step : vector<32xindex>
+/// gpu.yield %cst : vector<32xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+/// ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+struct WarpOpStep final : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+ if (!yieldOperand)
+ return failure();
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
+ auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+ VectorType resTy = stepOp.getResult().getType();
+ if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ llvm::formatv("Expected result size ({0}) to match the warp size ({1})",
+ resTy.getNumElements(), warpOp.getWarpSize()));
+ VectorType newVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ rewriter.setInsertionPointAfter(warpOp);
+ Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+ newVecTy, warpOp.getLaneid());
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
+ return success();
+ }
+};
+
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -846,8 +891,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
newYieldValues.reserve(warpOp->getNumResults());
DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
DenseMap<OpResult, int64_t> dedupResultPositionMap;
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +945,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Value valForwarded;
unsigned resultIndex;
for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1751,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto warpOpYield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
// Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = warpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
@@ -1826,7 +1868,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands);
+ forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
@@ -2019,7 +2062,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+ WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 491b448..7dde631 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
}
};
+/// This pattern linearizes `vector.from_elements` operations by converting
+/// the result type to a 1-D vector while preserving all element values.
+/// The transformation creates a linearized `vector.from_elements` followed by
+/// a `vector.shape_cast` to restore the original multidimensional shape.
+///
+/// Example:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
+/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+///
+struct LinearizeVectorFromElements final
+ : public OpConversionPattern<vector::FromElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorFromElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ LogicalResult
+ matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType dstTy =
+ getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
+ assert(dstTy && "vector type destination expected.");
+
+ OperandRange elements = fromElementsOp.getElements();
+ assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
+ "expected same number of elements");
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
+ elements);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore>(typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c707f38..369857f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (isa<ViewLikeOpInterface>(user)) {
- users.append(user->getUsers().begin(), user->getUsers().end());
+ if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
@@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (isa<ViewLikeOpInterface>(user)) {
- users.append(user->getUsers().begin(), user->getUsers().end());
+ if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2269a40..dbb5eb3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -600,7 +600,7 @@ struct BubbleDownVectorBitCastForExtract
// Get the first element of the mixed position as integer.
auto mixedPos = extractOp.getMixedPosition();
- if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
+ if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
return failure();
uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
@@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
- auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+ auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
if (!resType)
return failure();
if (resType.getRank() != 2)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abec..e8ecb0c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// decomposed shape from each of the index, mask, and pass-through
// vectors.
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+ loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
auto slicedGather = vector::GatherOp::create(
- rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+ rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
indexSubVec, maskSubVec, passThruSubVec);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 6e2fa35..841e138 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
}
return success();
}
+
+LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ vector::UnrollVectorOpFn unrollFn) {
+ assert(op->getNumResults() == 1 && "expected single result");
+ assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
+ VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op->getLoc();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ Value subVector = unrollFn(rewriter, loc, subTy, i);
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
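
Illustrative note: unrollVectorOp factors out the leading-dimension loop shared by UnrollGather and UnrollFromElements above. A hedged sketch of reusing it from another pattern; MyVecOp is hypothetical, and only the callback signature comes from the code above:

    // Inside some OpRewritePattern<MyVecOp>::matchAndRewrite(op, rewriter):
    return vector::unrollVectorOp(
        op, rewriter,
        [&](PatternRewriter &rw, Location loc, VectorType subTy, int64_t i) {
          // Build the rank-reduced op for slice i; unrollVectorOp inserts the
          // returned value into the poison-initialized result at position [i].
          return MyVecOp::create(rw, loc, subTy, op.getOperand());
        });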
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 7c6a4f3..7869a28 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -17,6 +17,8 @@ add_mlir_dialect_library(MLIRXeGPUDialect
MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRIR
MLIRViewLikeInterface
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d997296..7f3be7f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -67,7 +67,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc,
StaticTileOffsetRange(sizePerWg, distUnit)) {
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
- return builder.create<arith::ConstantIndexOp>(loc, d);
+ return arith::ConstantIndexOp::create(builder, loc, d);
});
SmallVector<Value> adds = llvm::map_to_vector(
@@ -80,7 +80,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc,
llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
- builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+ arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
offsets.push_back(mods);
@@ -91,7 +91,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc,
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
- xegpu::LayoutAttr attr) {
+ xegpu::DistributeLayoutAttr attr) {
assert(attr && "Layout attribute is missing.");
// Checks whether the given shape can be evenly distributed using the
@@ -104,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
// smaller than `layout[i] * data[i]`, allowing multiple compute units to
// share the data.
auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
- DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
+ SmallVector<int64_t> layout,
+ SmallVector<int64_t> data,
bool rr = true) -> optional<SmallVector<int64_t>> {
llvm::SmallVector<int64_t> newShape(shape);
- if (layout) {
- auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
- if (vec.size() != shape.size())
+ if (layout.size()) {
+ if (layout.size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(shape, vec);
+ auto ratio = computeShapeRatio(shape, layout);
if (!ratio.has_value())
return std::nullopt;
newShape = ratio.value();
}
- if (data) {
- auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
- if (vec.size() != shape.size())
+ if (data.size()) {
+ if (data.size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(newShape, vec);
+ auto ratio = computeShapeRatio(newShape, data);
if (!ratio.has_value() && rr)
- ratio = computeShapeRatio(vec, newShape);
+ ratio = computeShapeRatio(data, newShape);
if (!ratio.has_value())
return std::nullopt;
// if data is not null, we always return it for next phase.
- newShape = vec;
+ newShape = data;
}
return newShape;
};
// check the sgLayout and sgData
auto maybeSgShape =
- tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
+ tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
if (!maybeSgShape)
return false;
auto sgShape = maybeSgShape.value();
// check InstData; it neither has a layout nor needs round-robin
auto maybeInstShape =
- tryDistribute(sgShape, nullptr, attr.getInstData(), false);
+ tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false);
if (!maybeInstShape)
return false;
auto instShape = maybeInstShape.value();
// check LaneLayout and LaneData
- auto maybeLaneShape =
- tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
+ auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
+ attr.getLaneDataAsInt(), false);
return maybeLaneShape.has_value();
}
@@ -271,7 +270,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
Value linearId) {
// delinearizeSubgroupId is only available for
// workgroup-level layout attribute
- if (!isWgLayout())
+ if (!isForWorkgroup())
return failure();
// TODO: handle order attribute
@@ -283,29 +282,30 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");
- auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value {
+ auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value {
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
});
return affine::delinearizeIndex(builder, loc, linearId, dims);
}
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
- if (!isWgLayout())
+ if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
- SmallVector<int64_t> sgShape;
- if (auto maybeSgShape = getSgDataAsInt())
- sgShape = maybeSgShape.value();
- else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
- else
- return failure();
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
+ SmallVector<int64_t> sgShape = getSgDataAsInt();
+ if (sgShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+ }
// delinearize Ids
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
@@ -322,7 +322,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
- xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+ xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
if (!parent || !dims)
return emitError() << "expected parent layout and dims attribute";
@@ -340,7 +340,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
}
SliceAttr SliceAttr::flatten() const {
- xegpu::LayoutTrait parent = getParent();
+ xegpu::DistributeLayoutAttr parent = getParent();
SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,23 +375,24 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
return parent.delinearizeSubgroupId(builder, loc, linearId);
}
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
- if (!isWgLayout())
+ if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
- SmallVector<int64_t> sgShape;
- if (auto maybeSgShape = getSgDataAsInt())
- sgShape = maybeSgShape.value();
- else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
- else
- return failure();
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
+ SmallVector<int64_t> sgShape = getSgDataAsInt();
+ if (sgShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+ }
// delinearize Ids
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
@@ -427,7 +428,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
-mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+mlir::Type TensorDescType::parse(AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
@@ -477,7 +478,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
layout.value_or(mlir::Attribute()));
}
-void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+void TensorDescType::print(AsmPrinter &printer) const {
printer << "<";
auto shape = getShape();
@@ -522,10 +523,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, layout);
}
-LogicalResult TensorDescType::verify(
- llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
- llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
- mlir::Attribute encoding, mlir::Attribute layout) {
+LogicalResult
+TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
if (rank == 0)
@@ -591,6 +592,119 @@ LogicalResult TensorDescType::verify(
return success();
}
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
+mlir::Type MemDescType::parse(AsmParser &parser) {
+ llvm::SmallVector<int64_t> shape;
+ mlir::Type elementType;
+ mlir::FailureOr<MemLayoutAttr> layout;
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ auto shapeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
+ parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+ return {};
+ }
+
+ auto elemTypeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseType(elementType))) {
+ parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+ return {};
+ }
+
+ // parse optional attributes
+ if (mlir::succeeded(parser.parseOptionalComma())) {
+ MemLayoutAttr attr;
+ ParseResult res = parser.parseAttribute(attr);
+ if (mlir::failed(res))
+ return {};
+ layout = attr;
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ MLIRContext *ctxt = parser.getContext();
+ return MemDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+ elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+ printer << "<";
+
+ printer.printDimensionList(getShape());
+ printer << 'x';
+ printer << getElementType();
+
+ if (auto layout = getMemLayout())
+ printer << ", " << layout;
+
+ printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemLayoutAttr
+//===----------------------------------------------------------------------===//
+
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+
+ auto context = parser.getContext();
+ llvm::SMLoc loc = parser.getCurrentLocation();
+
+ llvm::SmallDenseSet<StringRef> seenKeys;
+ SmallVector<NamedAttribute> attributes;
+
+ auto parseElt = [&]() -> ParseResult {
+ StringRef nameId;
+ if (failed(parser.parseKeyword(&nameId)))
+ return parser.emitError(loc, "expected valid attribute name");
+
+ if (!seenKeys.insert(nameId).second)
+ return parser.emitError(loc, "duplicate key '")
+ << nameId << "' in mem layout attribute";
+
+ if (failed(parser.parseEqual()))
+ return failure();
+
+ Attribute attr;
+ if (failed(parser.parseAttribute(attr)))
+ return failure();
+ attributes.emplace_back(nameId, attr);
+ return success();
+ };
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ if (failed(parser.parseCommaSeparatedList(parseElt)))
+ return {};
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ return parser.getChecked<MemLayoutAttr>(
+ loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
+ for (size_t i = 0; i < attrs.size(); i++) {
+ printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
+ if (i < attrs.size() - 1)
+ printer << ", ";
+ }
+ printer << ">";
+}
+
} // namespace xegpu
} // namespace mlir
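
Illustrative note: the MemDescType parser/printer above imply an assembly form of shape, 'x', element type, then an optional trailing layout attribute. The '!xegpu.mem_desc' mnemonic and the layout key names below are assumptions (they would come from the ODS definitions, not from this diff):

    //   !xegpu.mem_desc<32x64xf16>
    //   !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32]>>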
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index fc11fa8..aca6654 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
namespace mlir {
namespace xegpu {
+bool isSharedMemory(const MemRefType &memrefTy) {
+ Attribute attr = memrefTy.getMemorySpace();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 3;
+ if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+ return memrefSpace.getValue() == MemorySpace::SLM;
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
@@ -45,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) {
return shape;
}
-static int64_t getRankOf(Value val) {
- auto type = val.getType();
- if (auto ty = llvm::dyn_cast<ShapedType>(type))
- return ty.getRank();
- return 0;
-}
-
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
if (!attr)
return true;
@@ -76,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
if (!tdescTy.isScattered())
return emitError() << "Expects a scattered TensorDesc.";
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (dyn_cast<VectorType>(maskTy))
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
- auto chunkSize = tdescTy.getChunkSizeAsInt();
if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
@@ -111,25 +122,49 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
}
static LogicalResult
-isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
- int64_t chunkSize,
+isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
+ VectorType valueTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto maskVecTy = dyn_cast<VectorType>(maskTy);
+ auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (maskVecTy || offsetsVecTy)
+ return emitError() << "Expecting scalar mask and offsets.";
+ else if (maskVecTy && offsetsVecTy)
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
+ auto valueSize = valueTy.getNumElements();
+ // SIMT mode with scalar mask and offsets.
+ if (!maskVecTy && !offsetsVecTy) {
+ if (valueSize != chunkSize)
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
- // a valid shape for SIMT case
- if (valueTy.getRank() == 1) {
- if (valueTy.getNumElements() != chunkSize)
- return emitError() << "value elements must match chunk size " << chunkSize
- << " for SIMT code.";
- return success();
+ if (!maskVecTy)
+ return emitError() << "Expecting a vector type mask.";
+ int64_t maskSize = maskVecTy.getNumElements();
+
+ if (chunkSize > 1) {
+ if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ } else {
+ if (valueSize != maskSize)
+ return emitError()
+ << "Mask should match value except the chunk size dim.";
}
-
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (maskSize == 1)
+ return success();
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
@@ -156,41 +191,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<MemRefType> source,
+ Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
+ Type srcTy = source.getType();
+ assert((isa<IntegerType, MemRefType>(srcTy)) &&
+ "Source has to be either int or memref.");
- llvm::SmallVector<int64_t> staticShape;
- llvm::SmallVector<int64_t> staticStrides;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
- dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
- auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
- build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
- dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
- staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<IntegerType> source,
- llvm::ArrayRef<OpFoldResult> shape,
- llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
-
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
- llvm::SmallVector<Value> dynamicShape;
- llvm::SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +210,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+ auto memrefShape = memrefTy.getShape();
+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+ // If shape and strides come from the memref, we don't need attributes for
+ // them; this keeps the printed IR clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ staticShapeAttr = DenseI64ArrayAttr();
+ staticStridesAttr = DenseI64ArrayAttr();
+ }
+ }
+
build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
staticStridesAttr);
@@ -265,8 +289,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- auto rank = (int64_t)getMixedOffsets().size();
- bool invalidRank = false;
+ size_t rank = getMixedSizes().size();
+ bool invalidRank = rank != getMixedStrides().size();
bool invalidElemTy = false;
// Memory space of created TensorDesc should match with the source.
@@ -280,31 +304,28 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
+ if (size_t offsetRank = getMixedOffsets().size())
+ invalidRank |= (offsetRank != rank);
+
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
- auto memrefTy = dyn_cast<MemRefType>(getSourceType());
- if (memrefTy) {
- invalidRank |= (memrefTy.getRank() != rank);
+ if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
- }
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
- return emitOpError("Expecting strides and shape to be present for "
+ return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
- // mismatches among shape, strides, and offsets are
- // already handeled by OffsetSizeAndStrideOpInterface.
- // So they are not check here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");
// check result TensorDesc rank
- if (getType().getRank() > rank)
+ if (getType().getRank() > (int64_t)rank)
return emitOpError(
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");
@@ -360,13 +381,10 @@ ParseResult parseOptionalDynamicIndexList(
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
DenseI64ArrayAttr integers) {
-
- if (!integers)
+ if (!integers || integers.empty())
return;
-
- return printDynamicIndexList(printer, op, values, integers,
- /*scalableFlags=*/{}, {},
- AsmParser::Delimiter::Square);
+ printDynamicIndexList(printer, op, values, integers,
+ /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
@@ -381,6 +399,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+ l2_hint, l3_hint);
+}
+
LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -423,6 +456,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
l3_hint);
}
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ UnitAttr packed, DenseI64ArrayAttr transpose,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -529,6 +578,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
@@ -635,10 +699,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CreateDescOp::verify() {
auto tdescTy = getTensorDescType();
- if (getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
@@ -673,12 +733,14 @@ LogicalResult CreateDescOp::verify() {
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -689,6 +751,13 @@ LogicalResult PrefetchOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+ auto srcTy = getSourceType();
+ if (srcTy.isInteger() && !getOffsetAlignByteAttr())
+ return emitOpError("offset_align_byte is required with integer source.");
+
+ if (getOffsetAlignByteAttr() && !srcTy.isInteger())
+ return emitOpError("offset_align_byte only allowed with integer source.");
+
return success();
}
@@ -696,7 +765,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
- build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
+ IntegerAttr{});
}
//===----------------------------------------------------------------------===//
@@ -707,13 +777,15 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
+
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -730,10 +802,11 @@ LogicalResult LoadGatherOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -746,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -754,12 +843,14 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
- if (!tdescTy && getRankOf(getDest()) > 1)
- return emitOpError(
- "Expecting the dest is a 1D memref or pointer (uint64_t).");
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -778,10 +869,11 @@ LogicalResult StoreScatterOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -794,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
@@ -888,8 +998,8 @@ LogicalResult ConvertLayoutOp::verify() {
// both input and target layouts should be WgLayout or SgLayout at the same
// time.
- if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
- (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+ if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
+ (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
return emitOpError("expected input layout and target layout be WgLayout or "
"SgLayout at the same time.");
@@ -928,6 +1038,101 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<FoldConvertLayoutOp>(context);
}
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadMatrixOp
+//===----------------------------------------------------------------------===//
+void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ DistributeLayoutAttr layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult LoadMatrixOp::verify() {
+ VectorType resTy = getRes().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> valueShape = resTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed mem_desc shape.");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreMatrixOp
+//===----------------------------------------------------------------------===//
+void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ DistributeLayoutAttr layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult StoreMatrixOp::verify() {
+ VectorType dataTy = getData().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("data shape must not exceed mem_desc shape.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescSubviewOp
+//===----------------------------------------------------------------------===//
+
+void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
+ Type resTy, Value src,
+ llvm::ArrayRef<OpFoldResult> offsets) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
+}
+
+LogicalResult MemDescSubviewOp::verify() {
+ MemDescType srcTy = getSrc().getType();
+ MemDescType resTy = getRes().getType();
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> resShape = resTy.getShape();
+
+ if (srcTy.getRank() < resTy.getRank())
+ return emitOpError("result rank must not exceed source rank.");
+
+ if (llvm::any_of(
+ llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed source shape.");
+
+ if (srcTy.getStrides() != resTy.getStrides())
+ return emitOpError("result must inherit the source strides.");
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
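
Illustrative note: a hedged usage sketch of the new OpFoldResult-based LoadGatherOp builder added above. The values loc, src, dynIdx, mask, and vecTy are assumed to exist in the caller; the static create(...) form matches the builder convention used throughout this diff:

    // Offsets may mix constants and SSA index values; the builder materializes
    // them into a vector via vector.from_elements before delegating.
    SmallVector<OpFoldResult> offsets = {builder.getIndexAttr(0), dynIdx};
    auto load = xegpu::LoadGatherOp::create(
        builder, loc, vecTy, src, offsets, mask,
        /*chunk_size=*/builder.getI64IntegerAttr(1),
        /*l1_hint=*/xegpu::CachePolicyAttr(),
        /*l2_hint=*/xegpu::CachePolicyAttr(),
        /*l3_hint=*/xegpu::CachePolicyAttr());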
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index d82c541..9ee002e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -84,9 +84,10 @@ struct ConvertLayoutOpPattern
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
PatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
- xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
- if (!input_layout.getInstData() || !target_layout.getInstData())
+ xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
+ xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
+ if (input_layout.getInstDataAsInt().empty() ||
+ target_layout.getInstDataAsInt().empty())
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
input_layout = input_layout.dropInstData();
@@ -140,10 +141,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
else
value = (Value)operandOrResult;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
- if (layout && layout.isSgLayout()) {
- if (auto inst_data = layout.getInstData())
- return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(operandOrResult);
+ if (layout && layout.isForSubgroup()) {
+ if (!layout.getInstDataAsInt().empty())
+ return layout.getInstDataAsInt();
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -204,13 +206,15 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
// skip the op if any of its operands or results has workgroup level layouts
bool hasWgLayoutOperands =
llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
- return layout && layout.isWgLayout();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(opr);
+ return layout && layout.isForWorkgroup();
});
bool hasWgLayoutResults =
llvm::any_of(op->getOpResults(), [](OpResult result) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
- return layout && layout.isWgLayout();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(result);
+ return layout && layout.isForWorkgroup();
});
if (hasWgLayoutOperands || hasWgLayoutResults) {
LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
@@ -220,8 +224,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
Type valTy = value.getType();
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
- xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
- return layout && layout.getInstData();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
+ return layout && !layout.getInstDataAsInt().empty();
}
auto shapedType = dyn_cast<ShapedType>(valTy);
return shapedType && !llvm::equal(tileShape, shapedType.getShape());
@@ -247,7 +251,8 @@ void XeGPUBlockingPass::runOnOperation() {
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
// This ensures that the LayoutAttr remains accessible even if the defining
// operation is replaced.
- xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
+ xegpu::setDistributeLayoutAttrs(
+ op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
xegpu::LayoutAttr layout) {
@@ -272,7 +277,7 @@ void XeGPUBlockingPass::runOnOperation() {
auto layout =
llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
- if (layout && layout.isWgLayout())
+ if (layout && layout.isForWorkgroup())
return failure();
int count;
@@ -289,7 +294,7 @@ void XeGPUBlockingPass::runOnOperation() {
ArrayRef<int64_t> shape = type.getShape();
xegpu::LayoutAttr layout = type.getLayoutAttr();
- if (layout && layout.isWgLayout())
+ if (layout && layout.isForWorkgroup())
return failure();
int count;
@@ -377,7 +382,7 @@ void XeGPUBlockingPass::runOnOperation() {
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<LoopLikeOpInterface>(op))
- xegpu::setLayoutAttr(result, layout.dropInstData());
+ xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
}
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef8804..5cb47b2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
@@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder,
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
if (auto result = dyn_cast<OpResult>(successorInput))
- xegpu::setLayoutAttr(result, successorOperandLayout);
+ xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
}
}
return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c..dddb5ea 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -336,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -449,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
// Make sure the same load op is the last operation in the warp op body.
// This ensure that load op is not sinked earlier violating any barrier
// synchronizations.
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
return yield->getPrevNode() == op;
});
@@ -752,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -794,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
@@ -841,14 +837,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+ auto layout =
+ xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
signalPassFailure();
return;
}
- xegpu::setLayoutAttr(operand, layout);
+ xegpu::setDistributeLayoutAttr(operand, layout);
}
});
// Step 2: Move all operations of a GPU function inside
@@ -882,7 +879,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (vecRank == 0)
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+ // TODO: support more layout types
+ auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 4a5525c..9f627c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,38 +34,29 @@ using namespace mlir;
namespace {
-// Check if there is sg id range attached to the scf.if op.
-static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
- int64_t &endOfRange) {
- Operation *parent = op->getParentOp();
- // Find the outermost scf::IfOp with xegpu.sg_id_range.
+// Retrieve the RangeAttr if it is specified.
+static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
+ Operation *parent = op->getParentOfType<scf::IfOp>();
while (parent) {
- if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
- if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
- ifOp->getAttr("sg_id_range"))) {
- startOfRange = attr.getStart().getInt();
- endOfRange = attr.getEnd().getInt();
- break;
- }
- }
- parent = parent->getParentOp();
+ if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
+ parent->getAttr("sg_id_range")))
+ return attr;
+ parent = parent->getParentOfType<scf::IfOp>();
}
- // Return false if startOfRange is 0
- return (startOfRange > 0 && endOfRange > startOfRange);
+ return {};
}
static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+ xegpu::DistributeLayoutAttr layout) {
int count = 1;
SmallVector<int64_t> sgShape(shape);
-
- if (layout && layout.isWgLayout()) {
- DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
- auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
- sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
- else
- sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+ if (layout && layout.isForWorkgroup()) {
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+ if (!layout.getSgDataAsInt().empty())
+ sgShape = layout.getSgDataAsInt();
+ else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+ sgShape = *maybeDerivedSgData;
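+ // For example, shape = [32, 64] with sg_layout = [2, 4] and no sg_data
+ // derives sgShape = [16, 16].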
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
return std::make_pair(sgShape, count);
}
+/// Utility helper for deriving the list of offsets for each sub-TensorDesc
+/// or sub-MemDesc to be accessed by the current subgroup (sgId), based on the
+/// associated distribute layout attribute, the shape, the subgroup id and the
+/// original offsets of the op.
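+/// For example (illustrative), with original offsets [%off0, %off1] and a
+/// workgroup layout that assigns this subgroup the relative offsets [d0, d1],
+/// the resulting entry is [%off0 + d0, %off1 + d1]; one such entry is
+/// produced per sub-descriptor assigned to the subgroup.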
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
+ SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+ // Not applicable to ops without offset operands.
+ if (origOffsets.empty())
+ return failure();
+
+ // Not applicable to ops without workgroup layout attributes.
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+ // verify and adjust the sgId if the range specifier is present
+ xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
+ if (sgIdRange) {
+ int64_t startOfRange = sgIdRange.getStart().getInt();
+ int64_t endOfRange = sgIdRange.getEnd().getInt();
+ // verify the RangeAttr against the layout attribute
+ if (layout.getNumSubgroups() != endOfRange - startOfRange)
+ return rewriter.notifyMatchFailure(
+ op, "sg_layout size must match the sg_id_range");
+ // adjust the sgId if necessary
+ if (startOfRange > 0) {
+ Value startOfRangeVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+ }
+ }
+
+ // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
+ // descriptors to be accessed, based on the layout information.
+ ArrayRef<int64_t> wgShape = op.getDataShape();
+ auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(maybeDescOffsets))
+ return failure();
+
+ // Compute the final global offsets for each accessed sub-tensor
+ // or sub-memory descriptor.
+ for (const auto &sgOffsets : *maybeDescOffsets) {
+ SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+ offsetsList.push_back(std::move(newOffsets));
+ }
+
+ return success();
+}
+
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
@@ -128,72 +180,72 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+ Type elemTy = tdescTy.getElementType();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ auto newTdescTy =
+ xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+ layout.dropSgLayoutAndData());
+
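+ // Create one subgroup-level descriptor per set of offsets computed above.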
+ SmallVector<Value> newOps;
+ for (auto offsets : offsetsList) {
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
+ op.getMixedSizes(), op.getMixedStrides());
+
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Check no offsets are specified.
+ if (!op.getMixedOffsets().empty())
+ return failure();
+
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
- if (!layout)
+ if (!layout || !layout.isForWorkgroup())
return failure();
+
Type elemTy = tdescTy.getElementType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
- // sgLayout must be present for workgroup-level distribution.
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
- return rewriter.notifyMatchFailure(
- op, "sgLayout attribute is required in layout");
-
- // Get the subgroup ID
- Value linearSgId =
- gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- int64_t startOfRange = -1, endOfRange = -1;
- bool sgIdRangeSpecified =
- isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
- if (sgIdRangeSpecified) {
- int64_t sgCount = endOfRange - startOfRange;
- if (computeProduct(sgLayout) != sgCount)
- return rewriter.notifyMatchFailure(
- op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get
- // the adjusted sg id
- Value startOfRangeVal =
- rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
- linearSgId =
- rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
- }
- auto maybeTdescOffsets =
- layout.getOffsets(rewriter, loc, linearSgId, wgShape);
- if (failed(maybeTdescOffsets))
- return failure();
-
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
xegpu::TensorDescType newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
layout.dropSgLayoutAndData());
- SmallVector<Value> newCreateNdOps;
- SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
-
- for (auto tdescOffsets : *maybeTdescOffsets) {
- SmallVector<OpFoldResult> sgOffsets;
- size_t rank = tdescOffsets.size();
- for (size_t i = 0; i < rank; i++) {
- size_t idx = wgOffsets.size() - rank + i;
- Value add = rewriter.createOrFold<index::AddOp>(
- loc, tdescOffsets[i],
- getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
- sgOffsets.push_back(add);
- }
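+ // Without offsets, all subgroups share the same descriptor, so `count`
+ // identical CreateNdDescOps are created.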
+ SmallVector<Value> newCreateNdOps(count);
+ std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+ return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+ op.getSource(), op.getMixedSizes(),
+ op.getMixedStrides());
+ });
- auto newOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
- op.getMixedSizes(), op.getMixedStrides());
- newCreateNdOps.push_back(newOp);
- }
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
return success();
}
@@ -205,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
LogicalResult
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> newLoadOps;
-
- int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
+ if (!op.getMixedOffsets().empty())
return failure();
+ SmallVector<Value> newLoadOps;
for (auto src : adaptor.getTensorDesc()) {
xegpu::TensorDescType tdescTy =
dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -233,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
LogicalResult
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
+ if (!op.getMixedOffsets().empty())
return failure();
for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
@@ -247,6 +295,84 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ SmallVector<Value> newOps;
+ for (auto [tdesc, offsets] :
+ llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+ auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+ VectorType newResTy =
+ VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+ auto newOp = xegpu::LoadNdOp::create(
+ rewriter, op.getLoc(), newResTy, tdesc, offsets,
+ /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+ : public OpConversionPattern<xegpu::StoreNdOp> {
+ using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ for (auto [v, tdesc, offsets] :
+ llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
+ rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+ : public OpConversionPattern<xegpu::PrefetchNdOp> {
+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ for (auto [tdesc, offsets] :
+ llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+ rewriter.create<xegpu::PrefetchNdOp>(
+ op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
@@ -280,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
if (resultTy.getRank() != 2)
return failure();
- auto originalLayout = xegpu::getLayoutAttr(op.getResult());
+ auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
if (!originalLayout)
return failure();
@@ -303,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
resultTy.getElementType());
tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
- xegpu::setLayoutAttr(cast<OpResult>(tmpC),
- originalLayout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
+ originalLayout.dropSgLayoutAndData());
newDpasOps.push_back(tmpC);
}
@@ -344,8 +470,9 @@ struct WgToSgVectorBroadcastOp
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
return failure();
// TODO: Currently only supports cases where the source and result ranks
@@ -360,10 +487,8 @@ struct WgToSgVectorBroadcastOp
VectorType::get(sgShape, resultType.getElementType());
// Check if the output layout is distributable
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+ if (sgLayout.empty())
return failure();
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
@@ -382,8 +507,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -409,8 +534,9 @@ struct WgToSgElementwiseOp : public ConversionPattern {
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
+ if (!layout || !layout.isForWorkgroup())
return failure();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -475,8 +601,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
// is lowered to:
// #a = #xegpu.layout<inst_data = [16, 16]>
// #b = #xegpu.layout<inst_data = [8, 16]>
-// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
// clang-format on
struct WgToSgConvertLayoutOp
@@ -485,10 +611,12 @@ struct WgToSgConvertLayoutOp
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input = op.getInputLayout();
- xegpu::LayoutAttr target = op.getTargetLayout();
+ // TODO: currently, we only support LayoutAttr
+ auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
+ auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
- if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+ if (!input || !target || !input.isForWorkgroup() ||
+ !target.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
@@ -598,16 +726,213 @@ struct UnrealizedConversionCastOpPattern
}
};
+// This pattern distributes an arith.constant op into subgroup-level constants.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ ArrayRef<int64_t> wgShape = vecType.getShape();
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+ // Current limitation: only splat vector constants (a single value) are
+ // supported.
+ // TODO: support more complex cases, e.g., vector with multiple values.
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+
+ auto newType = VectorType::get(sgShape, vecType.getElementType());
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
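+ // All subgroups share the same splat value; `count` copies of the
+ // subgroup-level constant replace the workgroup-level one.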
+ SmallVector<Value> newConsts(count, cstOp);
+
+ rewriter.replaceOpWithMultiple(op, {newConsts});
+ return success();
+ }
+};
+
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data
+struct WgToSgLoadGatherOpWithOffset
+ : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return failure();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ // The offsets need to be distributed
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
+ SmallVector<Value> newLoadOps;
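+ // Preserve the original chunk size, defaulting to 1 when unspecified.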
+ auto chunkSizeAttr =
+ rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+ VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+ for (auto [offsets, mask] :
+ llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newLoadOps.push_back(newLoadOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newLoadOps});
+ return success();
+ }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data
+struct WgToSgStoreScatterOpWithOffset
+ : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+ if (!valueType)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getValue());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ // The offsets need to be distributed
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
+ auto chunkSizeOpt = op.getChunkSize();
+ int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+ for (auto [val, offs, mask] : llvm::zip(
+ adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ // Update the layout attribute to drop sg_layout and sg_data.
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ op->setAttr("layout", newLayout);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+ using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ ArrayRef<int64_t> wgShape = op.getDataShape();
+ VectorType valueTy = op.getRes().getType();
+ Type elemTy = valueTy.getElementType();
+
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResTy = VectorType::get(sgShape, elemTy);
+ SmallVector<Value> newOps;
+ for (auto offsets : offsetsList) {
+ auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+ op.getLoc(), newResTy, op.getMemDesc(), offsets,
+ layout.dropSgLayoutAndData());
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
+struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
+ using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
+ rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
+ offsets,
+ layout.dropSgLayoutAndData());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
- patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ patterns
+ .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+ WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+ WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+ WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+ WgToSgStoreMatrixOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -697,8 +1022,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return xegpu::TensorDescType();
};
- auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
- return !layout || !layout.isWgLayout();
+ auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
+ return !layout || !layout.isForWorkgroup();
};
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
@@ -710,13 +1035,46 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
});
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
- auto layout = xegpu::getLayoutAttr(op.getResult());
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
+ [=](xegpu::LoadMatrixOp op) -> bool {
+ return isLegal(op.getLayoutAttr());
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
+ [=](xegpu::StoreMatrixOp op) -> bool {
+ return isLegal(op.getLayoutAttr());
+ });
+
+ target.addDynamicallyLegalOp<arith::ConstantOp>(
+ [=](arith::ConstantOp op) -> bool {
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecType)
+ return true;
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+ });
+
+ target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+ [=](xegpu::LoadGatherOp op) -> bool {
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+ [=](xegpu::StoreScatterOp op) -> bool {
+ // Check if the layout attribute is present on the result.
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+ if (!layout)
+ return true;
+ return isLegal(layout);
+ });
+
target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
@@ -744,7 +1102,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
}
}
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
return isLegal(layout);
});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index 98e84a4..d9bf4a1 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils
LINK_LIBS PUBLIC
MLIRIR
MLIRSCFTransforms
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRXeGPUDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2cf21fb..cac1ffe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,6 +11,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -38,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
// It only works for subgroup level layout, which only has lane_layout
// and lane_data, and is to distribute a SIMD code into SIMT code.
- if (!layout || !layout.isSgLayout())
+ if (!layout || !layout.isForSubgroup())
return failure();
SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
@@ -111,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) {
return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (!value)
return nullptr;
@@ -129,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
// for LoadNdOp, the layout is stored in the tensor descriptor
if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
- return getLayoutAttr(loadNd.getTensorDesc());
+ return getDistributeLayoutAttr(loadNd.getTensorDesc());
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
- return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+ return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -141,49 +144,51 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
- return getLayoutAttr(tiedInit->get());
+ return getDistributeLayoutAttr(tiedInit->get());
}
}
return nullptr;
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
- return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
- return getLayoutAttr(opr.get());
+ return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return getDistributeLayoutAttr(opr.get());
}
template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
+ const DistributeLayoutAttr layout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
+ if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->setAttr(name, layout);
}
// Explicit instantiation for OpResult
-template void
-xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
- const mlir::xegpu::LayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
+ const mlir::OpResult &result,
+ const mlir::xegpu::DistributeLayoutAttr layout);
// Explicit instantiation for OpOperand
-template void
-xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
- const mlir::xegpu::LayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
+ const mlir::OpOperand &operand,
+ const mlir::xegpu::DistributeLayoutAttr layout);
-void xegpu::setLayoutAttrs(Operation *op,
- function_ref<LayoutAttr(Value)> getLayoutImpl) {
+void xegpu::setDistributeLayoutAttrs(
+ Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getLayoutImpl(opr.get());
- setLayoutAttr(opr, layout);
+ setDistributeLayoutAttr(opr, layout);
}
for (OpResult result : nestOp->getOpResults()) {
auto layout = getLayoutImpl(result);
- setLayoutAttr(result, layout);
+ setDistributeLayoutAttr(result, layout);
}
});
}
@@ -192,7 +197,7 @@ template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (owner->hasAttrOfType<LayoutAttr>(name))
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->removeAttr(name);
}
@@ -303,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
if (!inputTy || !resultTy)
return WalkResult::skip();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(input);
if (!layout)
return WalkResult::skip();
@@ -341,7 +347,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
}
{ // perform the conversion from RankedTensorType to VectorType based on the
- // LayoutAttr
+ // DistributeLayoutAttr
// Handle the UnrealizedConversionCastOp introduced by the first step.
// For vector->RankedTensorType, it will simply forward the inputs.
@@ -404,3 +410,49 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
(void)mlir::applyPartialConversion(op, target, std::move(patterns));
}
}
+
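+/// Returns the chip name from the XeVM target attribute of the enclosing
+/// gpu.module, if present.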
+std::optional<std::string> xegpu::getChipStr(Operation *op) {
+ auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
+
+ if (!gpuModuleOp)
+ return std::nullopt;
+
+ auto targetAttrs = gpuModuleOp.getTargets();
+ if (targetAttrs) {
+ for (auto &attr : *targetAttrs) {
+ auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
+ if (xevmAttr)
+ return xevmAttr.getChip().str();
+ }
+ }
+
+ return std::nullopt;
+}
+
+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+/// lhs = [l1, l2, l3], rhs = [r1, r2]
+/// Result: [l1, l2+r1, l3+r2]
+SmallVector<OpFoldResult>
+xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> lhs,
+ ArrayRef<OpFoldResult> rhs) {
+ // Ensure `a` is at least as long as `b`.
+ ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+ ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+ SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+ a = a.slice(a.size() - b.size());
+ for (auto [l, r] : llvm::zip(a, b)) {
+ auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+ auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+ results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ }
+ return results;
+}
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index f704fbf..52162a4 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
}
// Compilation is lazy and it doesn't populate object cache unless requested.
// In case object dump is requested before cache is populated, we need to
- // force compilation manually.
+ // force compilation manually.
if (cache->isEmpty()) {
for (std::string &functionName : functionNames) {
auto result = lookupPacked(functionName);
@@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
return symbolMap;
};
engine->registerSymbols(runtimeSymbolMap);
-
- // Execute the global constructors from the module being processed.
- // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
- // crash for AArch64 see related issue #71963.
- if (!engine->jit->getTargetTriple().isAArch64())
- cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
-
return std::move(engine);
}
@@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
Error ExecutionEngine::invokePacked(StringRef name,
MutableArrayRef<void *> args) {
+ initialize();
auto expectedFPtr = lookupPacked(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
@@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,
return Error::success();
}
+
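+// Lazily performs JIT initialization (running the module's global
+// constructors where supported); repeated calls are no-ops.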
+void ExecutionEngine::initialize() {
+ if (isInitialized)
+ return;
+ // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
+ // crash for AArch64 see related issue #71963.
+ if (!jit->getTargetTriple().isAArch64())
+ cantFail(jit->initialize(jit->getMainJITDylib()));
+ isInitialized = true;
+}
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 2107df3..0ada4cc 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
auto engine = std::move(*expectedEngine);
+ engine->initialize();
+
auto expectedFPtr = engine->lookupPacked(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
diff --git a/mlir/lib/ExecutionEngine/VulkanRuntime.cpp b/mlir/lib/ExecutionEngine/VulkanRuntime.cpp
index 9f653b2..9452a56 100644
--- a/mlir/lib/ExecutionEngine/VulkanRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/VulkanRuntime.cpp
@@ -20,7 +20,7 @@
#include <iomanip>
#include <iostream>
-inline void emitVulkanError(const char *api, VkResult error) {
+static inline void emitVulkanError(const char *api, VkResult error) {
std::cerr << " failed with error code " << error << " when executing " << api;
}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 57825d9..27b47e2 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -251,6 +251,16 @@ bool Block::mightHaveTerminator() {
return !empty() && back().mightHaveTrait<OpTrait::IsTerminator>();
}
+iterator_range<Block::iterator> Block::without_terminator_impl() {
+ // Note: When the op is unregistered, we do not know for sure if the last
+ // op is a terminator. In that case, we include it in `without_terminator`,
+ // but that decision is somewhat arbitrary.
+ if (!back().hasTrait<OpTrait::IsTerminator>())
+ return {begin(), end()};
+ auto endIt = --end();
+ return {begin(), endIt};
+}
+
// Indexed successor access.
unsigned Block::getNumSuccessors() {
return empty() ? 0 : back().getNumSuccessors();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fd898b7..6f880f8 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Types.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include <optional>
@@ -1119,9 +1120,8 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
auto denseEltBitWidth = getDenseElementBitWidth(type);
auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT);
if (denseEltBitWidth != dataSize) {
- LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
- << denseEltBitWidth << " to match data size "
- << dataSize << " for type " << type << "\n");
+ LDBG() << "expected dense element bit width " << denseEltBitWidth
+ << " to match data size " << dataSize << " for type " << type;
return false;
}
@@ -1129,9 +1129,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
if (!isInt) {
bool valid = llvm::isa<FloatType>(type);
if (!valid)
- LLVM_DEBUG(llvm::dbgs()
- << "expected float type when isInt is false, but found "
- << type << "\n");
+ LDBG() << "expected float type when isInt is false, but found " << type;
return valid;
}
if (type.isIndex())
@@ -1139,9 +1137,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
auto intType = llvm::dyn_cast<IntegerType>(type);
if (!intType) {
- LLVM_DEBUG(llvm::dbgs()
- << "expected integer type when isInt is true, but found " << type
- << "\n");
+ LDBG() << "expected integer type when isInt is true, but found " << type;
return false;
}
@@ -1151,8 +1147,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
bool valid = intType.isSigned() == isSigned;
if (!valid)
- LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
- << " to match type " << type << "\n");
+ LDBG() << "expected signedness " << isSigned << " to match type " << type;
return valid;
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69ce..d95bdc9 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_library(MLIRIR
PatternMatch.cpp
Region.cpp
RegionKindInterface.cpp
+ Remarks.cpp
SymbolTable.cpp
TensorEncoding.cpp
Types.cpp
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index f84fe89..952619b 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -19,7 +19,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Regex.h"
#include <memory>
@@ -104,14 +104,8 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
auto it = registeredInterfaces.try_emplace(interface->getID(),
std::move(interface));
- (void)it;
- LLVM_DEBUG({
- if (!it.second) {
- llvm::dbgs() << "[" DEBUG_TYPE
- "] repeated interface registration for dialect "
- << getNamespace();
- }
- });
+ if (!it.second)
+ LDBG() << "repeated interface registration for dialect " << getNamespace();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 23e70c6..662681e 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -52,8 +52,8 @@ struct FileLineColRangeAttrStorage final
FileLineColRangeAttrStorage::totalSizeToAlloc<unsigned>(locEnc - 1);
auto *rawMem =
allocator.allocate(byteSize, alignof(FileLineColRangeAttrStorage));
- auto *result = ::new (rawMem) FileLineColRangeAttrStorage(
- std::move(std::get<0>(tblgenKey)), locEnc - 1);
+ auto *result = ::new (rawMem)
+ FileLineColRangeAttrStorage(std::get<0>(tblgenKey), locEnc - 1);
if (numInArray > 0) {
ArrayRef<unsigned> elements = std::get<1>(tblgenKey);
result->startLine = elements[0];
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2d5381d..1fa04ed 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -25,12 +25,13 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Remarks.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
@@ -134,6 +135,11 @@ public:
DiagnosticEngine diagEngine;
//===--------------------------------------------------------------------===//
+ // Remark
+ //===--------------------------------------------------------------------===//
+ std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
+
+ //===--------------------------------------------------------------------===//
// Options
//===--------------------------------------------------------------------===//
@@ -388,6 +394,19 @@ bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
//===----------------------------------------------------------------------===//
+// Remark Handlers
+//===----------------------------------------------------------------------===//
+
+void MLIRContext::setRemarkEngine(
+ std::unique_ptr<remark::detail::RemarkEngine> engine) {
+ getImpl().remarkEngine = std::move(engine);
+}
+
+remark::detail::RemarkEngine *MLIRContext::getRemarkEngine() {
+ return getImpl().remarkEngine.get();
+}
+
+//===----------------------------------------------------------------------===//
// Dialect and Operation Registration
//===----------------------------------------------------------------------===//
@@ -455,8 +474,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);
if (dialectIt.second) {
- LLVM_DEBUG(llvm::dbgs()
- << "Load new dialect in Context " << dialectNamespace << "\n");
+ LDBG() << "Load new dialect in Context " << dialectNamespace;
#ifndef NDEBUG
if (impl.multiThreadedExecutionContext != 0)
llvm::report_fatal_error(
@@ -525,8 +543,7 @@ DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
"' has already been registered");
}
- LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
- << dialectNamespace << "\n");
+ LDBG() << "Load new dynamic dialect in Context " << dialectNamespace;
#ifndef NDEBUG
if (impl.multiThreadedExecutionContext != 0)
llvm::report_fatal_error(
@@ -1192,11 +1209,10 @@ willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
maxSymbolPosition);
if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
- LLVM_DEBUG(
- llvm::dbgs()
+ LDBG()
<< "maximum dimensional identifier position in result expression must "
"be less than `dimCount` and maximum symbolic identifier position "
- "in result expression must be less than `symbolCount`\n");
+ "in result expression must be less than `symbolCount`";
return false;
}
return true;
diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp
index 69b4a56..f2665d2 100644
--- a/mlir/lib/IR/ODSSupport.cpp
+++ b/mlir/lib/IR/ODSSupport.cpp
@@ -112,7 +112,7 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, bool storage) {
}
template <typename DenseArrayTy, typename T>
-LogicalResult
+static LogicalResult
convertDenseArrayFromAttr(MutableArrayRef<T> storage, Attribute attr,
function_ref<InFlightDiagnostic()> emitError,
StringRef denseArrayTyStr) {
@@ -143,7 +143,7 @@ mlir::convertFromAttribute(MutableArrayRef<int32_t> storage, Attribute attr,
}
template <typename DenseArrayTy, typename T>
-LogicalResult
+static LogicalResult
convertDenseArrayFromAttr(SmallVectorImpl<T> &storage, Attribute attr,
function_ref<InFlightDiagnostic()> emitError,
StringRef denseArrayTyStr) {
diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp
new file mode 100644
index 0000000..78c9644
--- /dev/null
+++ b/mlir/lib/IR/Remarks.cpp
@@ -0,0 +1,279 @@
+//===- Remarks.cpp - MLIR Remarks -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Remarks.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+
+using namespace mlir::remark::detail;
+
+//------------------------------------------------------------------------------
+// Remark
+//------------------------------------------------------------------------------
+
+Remark::Arg::Arg(llvm::StringRef k, Value v) : key(k) {
+ llvm::raw_string_ostream os(val);
+ os << v;
+}
+
+Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) {
+ llvm::raw_string_ostream os(val);
+ os << t;
+}
+
+void Remark::insert(llvm::StringRef s) { args.emplace_back(s); }
+void Remark::insert(Arg a) { args.push_back(std::move(a)); }
+
+// Simple helper to print key=val list (sorted).
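+// For example, args {pass="loop unroll", count=4} print as:
+//   count=4, pass="loop unroll"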
+static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
+ if (args.empty())
+ return;
+
+ llvm::SmallVector<Remark::Arg, 8> sorted(args.begin(), args.end());
+ llvm::sort(sorted, [](const Remark::Arg &a, const Remark::Arg &b) {
+ return a.key < b.key;
+ });
+
+ for (size_t i = 0; i < sorted.size(); ++i) {
+ const auto &a = sorted[i];
+ os << a.key << "=";
+
+ llvm::StringRef val(a.val);
+ bool needsQuote = val.contains(' ') || val.contains(',') ||
+ val.contains('{') || val.contains('}');
+ if (needsQuote)
+ os << '"' << val << '"';
+ else
+ os << val;
+
+ if (i + 1 < sorted.size())
+ os << ", ";
+ }
+}
+
+/// Print the remark to the given output stream.
+/// Example output:
+// clang-format off
/// [Missed] Unroller | Category:Loop | Function=main | Reason="tripCount=4 < threshold=256"
+/// [Failure] LoopOptimizer | Reason="failed due to unsupported pattern"
+// clang-format on
+void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
+ // Header: [Type] pass:remarkName
+ StringRef type = getRemarkTypeString();
+ StringRef categoryName = getFullCategoryName();
+ StringRef name = remarkName;
+
+ os << '[' << type << "] ";
+ os << name << " | ";
+ if (!categoryName.empty())
+ os << "Category:" << categoryName << " | ";
+ if (!functionName.empty())
+ os << "Function=" << getFunction() << " | ";
+
+ if (printLocation) {
+ if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation()))
+ os << " @" << flc.getFilename() << ":" << flc.getLine() << ":"
+ << flc.getColumn();
+ }
+
+ printArgs(os, getArgs());
+}
+
+std::string Remark::getMsg() const {
+ std::string s;
+ llvm::raw_string_ostream os(s);
+ print(os);
+ os.flush();
+ return s;
+}
+
+llvm::StringRef Remark::getRemarkTypeString() const {
+ switch (remarkKind) {
+ case RemarkKind::RemarkUnknown:
+ return "Unknown";
+ case RemarkKind::RemarkPassed:
+ return "Passed";
+ case RemarkKind::RemarkMissed:
+ return "Missed";
+ case RemarkKind::RemarkFailure:
+ return "Failure";
+ case RemarkKind::RemarkAnalysis:
+ return "Analysis";
+ }
+ llvm_unreachable("Unknown remark kind");
+}
+
+llvm::remarks::Type Remark::getRemarkType() const {
+ switch (remarkKind) {
+ case RemarkKind::RemarkUnknown:
+ return llvm::remarks::Type::Unknown;
+ case RemarkKind::RemarkPassed:
+ return llvm::remarks::Type::Passed;
+ case RemarkKind::RemarkMissed:
+ return llvm::remarks::Type::Missed;
+ case RemarkKind::RemarkFailure:
+ return llvm::remarks::Type::Failure;
+ case RemarkKind::RemarkAnalysis:
+ return llvm::remarks::Type::Analysis;
+ }
+ llvm_unreachable("Unknown remark kind");
+}
+
+llvm::remarks::Remark Remark::generateRemark() const {
+ auto locLambda = [&]() -> llvm::remarks::RemarkLocation {
+ if (auto flc = dyn_cast<FileLineColLoc>(getLocation()))
+ return {flc.getFilename(), flc.getLine(), flc.getColumn()};
+ return {"<unknown file>", 0, 0};
+ };
+
+ llvm::remarks::Remark r; // The result.
+ r.RemarkType = getRemarkType();
+ r.RemarkName = getRemarkName();
+ // MLIR does not use passes; instead, it has categories and sub-categories.
+ r.PassName = getFullCategoryName();
+ r.FunctionName = getFunction();
+ r.Loc = locLambda();
+ for (const Remark::Arg &arg : getArgs()) {
+ r.Args.emplace_back();
+ r.Args.back().Key = arg.key;
+ r.Args.back().Val = arg.val;
+ }
+ return r;
+}
+
+//===----------------------------------------------------------------------===//
+// InFlightRemark
+//===----------------------------------------------------------------------===//
+
+InFlightRemark::~InFlightRemark() {
+ if (remark && owner)
+ owner->report(std::move(*remark));
+ owner = nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Remark Engine
+//===----------------------------------------------------------------------===//
+
+template <typename RemarkT, typename... Args>
+InFlightRemark RemarkEngine::makeRemark(Args &&...args) {
+ static_assert(std::is_base_of_v<Remark, RemarkT>,
+ "RemarkT must derive from Remark");
+ return InFlightRemark(*this,
+ std::make_unique<RemarkT>(std::forward<Args>(args)...));
+}
+
+template <typename RemarkT>
+InFlightRemark
+RemarkEngine::emitIfEnabled(Location loc, RemarkOpts opts,
+ bool (RemarkEngine::*isEnabled)(StringRef) const) {
+ return (this->*isEnabled)(opts.categoryName) ? makeRemark<RemarkT>(loc, opts)
+ : InFlightRemark{};
+}
+
+bool RemarkEngine::isMissedOptRemarkEnabled(StringRef categoryName) const {
+ return missFilter && missFilter->match(categoryName);
+}
+
+bool RemarkEngine::isPassedOptRemarkEnabled(StringRef categoryName) const {
+ return passedFilter && passedFilter->match(categoryName);
+}
+
+bool RemarkEngine::isAnalysisOptRemarkEnabled(StringRef categoryName) const {
+ return analysisFilter && analysisFilter->match(categoryName);
+}
+
+bool RemarkEngine::isFailedOptRemarkEnabled(StringRef categoryName) const {
+ return failedFilter && failedFilter->match(categoryName);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemark(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkPass>(loc, opts,
+ &RemarkEngine::isPassedOptRemarkEnabled);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemarkMiss(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkMissed>(
+ loc, opts, &RemarkEngine::isMissedOptRemarkEnabled);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemarkFailure(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkFailure>(
+ loc, opts, &RemarkEngine::isFailedOptRemarkEnabled);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkAnalysis>(
+ loc, opts, &RemarkEngine::isAnalysisOptRemarkEnabled);
+}
+
+//===----------------------------------------------------------------------===//
+// RemarkEngine
+//===----------------------------------------------------------------------===//
+
+void RemarkEngine::report(const Remark &&remark) {
+  // Stream the remark through the installed streamer, if any.
+ if (remarkStreamer)
+ remarkStreamer->streamOptimizationRemark(remark);
+
+  // Print using MLIR's diagnostic infrastructure, if requested.
+ if (printAsEmitRemarks)
+ emitRemark(remark.getLocation(), remark.getMsg());
+}
+
+RemarkEngine::~RemarkEngine() {
+ if (remarkStreamer)
+ remarkStreamer->finalize();
+}
+
+llvm::LogicalResult
+RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::string *errMsg) {
+ // If you need to validate categories/filters, do so here and set errMsg.
+ remarkStreamer = std::move(streamer);
+ return success();
+}
+
+RemarkEngine::RemarkEngine(bool printAsEmitRemarks,
+ const RemarkCategories &cats)
+ : printAsEmitRemarks(printAsEmitRemarks) {
+ if (cats.passed)
+ passedFilter = llvm::Regex(cats.passed.value());
+ if (cats.missed)
+ missFilter = llvm::Regex(cats.missed.value());
+ if (cats.analysis)
+ analysisFilter = llvm::Regex(cats.analysis.value());
+ if (cats.failed)
+ failedFilter = llvm::Regex(cats.failed.value());
+}
+
+llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
+ MLIRContext &ctx,
+ std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
+ const remark::RemarkCategories &cats, bool printAsEmitRemarks) {
+ auto engine =
+ std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats);
+
+ std::string errMsg;
+ if (failed(engine->initialize(std::move(streamer), &errMsg))) {
+ llvm::report_fatal_error(
+ llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg);
+ }
+ ctx.setRemarkEngine(std::move(engine));
+
+ return success();
+}
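+
+// Illustrative usage sketch (an assumption, not code added by this patch; it
+// presumes the RemarkCategories fields accept string assignment, as the
+// std::optional-style access in the RemarkEngine constructor suggests):
+//
+//   remark::RemarkCategories cats;
+//   cats.missed = "Unroll.*"; // regex applied to category names
+//   (void)remark::enableOptimizationRemarks(ctx, /*streamer=*/nullptr, cats,
+//                                           /*printAsEmitRemarks=*/true);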
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index 266f6db..b5a6888 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -304,6 +304,11 @@ template bool
mlir::hasEffect<BlockArgument, MemoryEffects::Write, MemoryEffects::Free>(
Operation *, BlockArgument);
+bool mlir::hasUnknownEffects(Operation *op) {
+ return !isa<MemoryEffectOpInterface>(op) &&
+ !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+}
+
bool mlir::wouldOpBeTriviallyDead(Operation *op) {
if (op->mightHaveTrait<OpTrait::IsTerminator>())
return false;
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939..af4ea5a 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -290,8 +290,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
DivisionFixupFn fixup) {
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
&rhsMax = rhs.umax();
-
- if (!rhsMin.isZero()) {
+ if (!rhsMin.isZero() && !rhsMax.isZero()) {
auto udiv = [&fixup](const APInt &a,
const APInt &b) -> std::optional<APInt> {
return fixup(a, b, a.udiv(b));
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index c9481fb..caa9091 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -813,7 +813,7 @@ FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
return false;
// Keep processing as long as the strong relation cannot be proven.
FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
- return failed(ordered) ? true : false;
+ return failed(ordered);
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 98ed4cc..ec18c48 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -35,7 +35,7 @@ public:
std::optional<DynMatcher> getDynMatcher() const override {
std::vector<DynMatcher> dynMatchers;
- for (auto variantMatcher : args) {
+ for (const auto &variantMatcher : args) {
std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
if (dynMatcher)
dynMatchers.push_back(dynMatcher.value());
@@ -66,8 +66,7 @@ VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
VariantMatcher
VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
ArrayRef<VariantMatcher> args) {
- return VariantMatcher(
- std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
+ return VariantMatcher(std::make_shared<VariadicOpPayload>(varOp, args));
}
std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 03e4177..375e820 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -141,7 +141,7 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
os << "\n";
for (auto &results : matches) {
os << "Match #" << ++matchCount << ":\n\n";
- for (auto op : results.matchedOps) {
+ for (Operation *op : results.matchedOps) {
if (op == results.rootOp) {
finder.printMatch(os, qs, op, "root");
} else {
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 950b85e2..258fed1 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -102,6 +102,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "mlir/Target/SPIRV/Target.h"
/// Add all the MLIR dialects to the provided registry.
@@ -199,6 +200,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+ xevm::registerXeVMTargetInterfaceExternalModels(registry);
}
/// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 8f7c67c..69a85db 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -28,6 +28,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -58,6 +59,7 @@
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
/// This function may be called to register all MLIR dialect extensions with the
/// provided registry.
@@ -80,6 +82,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
registerConvertMemRefToEmitCInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
+ ptr::registerConvertPtrToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index 1ed3a37..c67b242 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -45,6 +45,7 @@
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
// This function may be called to register the MLIR passes with the
@@ -74,6 +75,7 @@ void mlir::registerAllPasses() {
registerNVGPUPasses();
registerSparseTensorPasses();
LLVM::registerLLVMPasses();
+ LLVM::registerTargetLLVMIRTransformsPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
shard::registerShardPasses();
diff --git a/mlir/lib/Remark/CMakeLists.txt b/mlir/lib/Remark/CMakeLists.txt
new file mode 100644
index 0000000..920a95d
--- /dev/null
+++ b/mlir/lib/Remark/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_library(MLIRRemarkStreamer
+ RemarkStreamer.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Remark
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+
+ LINK_COMPONENTS
+ Remarks
+ Core
+ BitstreamReader
+ )
diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp
new file mode 100644
index 0000000..8e3544f
--- /dev/null
+++ b/mlir/lib/Remark/RemarkStreamer.cpp
@@ -0,0 +1,69 @@
+#include "mlir/Remark/RemarkStreamer.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Remarks.h"
+
+#include "llvm/Remarks/RemarkSerializer.h"
+#include "llvm/Remarks/RemarkStreamer.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+namespace mlir::remark::detail {
+
+FailureOr<std::unique_ptr<MLIRRemarkStreamerBase>>
+LLVMRemarkStreamer::createToFile(llvm::StringRef path,
+ llvm::remarks::Format fmt) {
+ std::error_code ec;
+ // Use error_code ctor; YAML is text. (Bitstream also works fine here.)
+ auto f =
+ std::make_unique<llvm::ToolOutputFile>(path, ec, llvm::sys::fs::OF_Text);
+ if (ec)
+ return failure();
+
+ auto serOr = llvm::remarks::createRemarkSerializer(
+ fmt, llvm::remarks::SerializerMode::Separate, f->os());
+ if (!serOr) {
+ llvm::consumeError(serOr.takeError());
+ return failure();
+ }
+
+ auto rs =
+ std::make_unique<llvm::remarks::RemarkStreamer>(std::move(*serOr), path);
+
+ auto impl = std::unique_ptr<LLVMRemarkStreamer>(new LLVMRemarkStreamer());
+ impl->remarkStreamer = std::move(rs);
+ impl->file = std::move(f);
+ return std::unique_ptr<MLIRRemarkStreamerBase>(std::move(impl));
+}
+
+void LLVMRemarkStreamer::streamOptimizationRemark(const Remark &remark) {
+ if (!remarkStreamer->matchesFilter(remark.getCategoryName()))
+ return;
+
+ // First, convert the diagnostic to a remark.
+ llvm::remarks::Remark r = remark.generateRemark();
+ // Then, emit the remark through the serializer.
+ remarkStreamer->getSerializer().emit(r);
+}
+
+LLVMRemarkStreamer::~LLVMRemarkStreamer() {
+ if (file && remarkStreamer)
+ file->keep();
+}
+} // namespace mlir::remark::detail
+
+namespace mlir::remark {
+LogicalResult enableOptimizationRemarksWithLLVMStreamer(
+ MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt,
+ const RemarkCategories &cat, bool printAsEmitRemarks) {
+
+ FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr =
+ detail::LLVMRemarkStreamer::createToFile(path, fmt);
+ if (failed(sOr))
+ return failure();
+
+ return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat,
+ printAsEmitRemarks);
+}
+
+} // namespace mlir::remark
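+
+// Illustrative usage sketch (an assumption for exposition; it presumes
+// llvm::remarks::Format::YAML as the serialization format and fills the
+// categories struct as in Remarks.cpp):
+//
+//   remark::RemarkCategories cats;
+//   cats.passed = ".*";
+//   if (failed(remark::enableOptimizationRemarksWithLLVMStreamer(
+//           ctx, "remarks.yaml", llvm::remarks::Format::YAML, cats,
+//           /*printAsEmitRemarks=*/false)))
+//     return failure();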
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ae3f22d..5cbea5d 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -20,8 +20,10 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InterleavedRange.h"
#include <numeric>
#include <optional>
@@ -707,10 +709,8 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
}
// Print the index usage and ensure that we did not run out of index space.
- LLVM_DEBUG({
- llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
- << "(down from initial " << valueDefRanges.size() << ").\n";
- });
+ LDBG() << "Allocated " << allocatedIndices.size() << " indices "
+ << "(down from initial " << valueDefRanges.size() << ").";
assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
"Ran out of memory for allocated indices");
@@ -736,6 +736,7 @@ void Generator::generate(Region *region, ByteCodeWriter &writer) {
}
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
+ LDBG() << "Generating bytecode for operation: " << op->getName();
LLVM_DEBUG({
// The following list must contain all the operations that do not
// produce any bytecode.
@@ -1275,12 +1276,8 @@ private:
/// Handle a switch operation with the provided value and cases.
template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
- LLVM_DEBUG({
- llvm::dbgs() << " * Value: " << value << "\n"
- << " * Cases: ";
- llvm::interleaveComma(cases, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "Switch operation:\n * Value: " << value
+ << "\n * Cases: " << llvm::interleaved(cases);
// Check to see if the attribute value is within the case list. Jump to
// the correct successor index based on the result.
@@ -1424,38 +1421,27 @@ private:
} // namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
+ LDBG() << "Executing ApplyConstraint:";
ByteCodeField fun_idx = read();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Arguments: " << llvm::interleaved(args);
ByteCodeField isNegated = read();
- LLVM_DEBUG({
- llvm::dbgs() << " * isNegated: " << isNegated << "\n";
- llvm::interleaveComma(args, llvm::dbgs());
- });
+ LDBG() << " * isNegated: " << isNegated;
ByteCodeField numResults = read();
const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
[[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
- LLVM_DEBUG({
- if (succeeded(rewriteResult)) {
- llvm::dbgs() << " * Constraint succeeded\n";
- llvm::dbgs() << " * Results: ";
- llvm::interleaveComma(constraintResults, llvm::dbgs());
- llvm::dbgs() << "\n";
- } else {
- llvm::dbgs() << " * Constraint failed\n";
- }
- });
+ if (succeeded(rewriteResult)) {
+ LDBG() << " * Constraint succeeded, results: "
+ << llvm::interleaved(constraintResults);
+ } else {
+ LDBG() << " * Constraint failed";
+ }
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
"native PDL rewrite function succeeded but returned "
"unexpected number of results");
@@ -1466,15 +1452,12 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
}
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
+ LDBG() << "Executing ApplyRewrite:";
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- });
+ LDBG() << " * Arguments: " << llvm::interleaved(args);
// Execute the rewrite function.
ByteCodeField numResults = read();
@@ -1487,7 +1470,7 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
processNativeFunResults(results, numResults, rewriteResult);
if (failed(rewriteResult)) {
- LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ LDBG() << " - Failed";
return failure();
}
return success();
@@ -1516,7 +1499,7 @@ void ByteCodeExecutor::processNativeFunResults(
PDLValue::Kind resultKind = read<PDLValue::Kind>();
(void)resultKind;
PDLValue result = results.getResults()[resultIdx];
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
+ LDBG() << " * Result: " << result;
assert(result.getKind() == resultKind &&
"native PDL rewrite function returned an unexpected type of "
"result");
@@ -1544,16 +1527,16 @@ void ByteCodeExecutor::processNativeFunResults(
}
void ByteCodeExecutor::executeAreEqual() {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+ LDBG() << "Executing AreEqual:";
const void *lhs = read<const void *>();
const void *rhs = read<const void *>();
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(lhs == rhs);
}
void ByteCodeExecutor::executeAreRangesEqual() {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
+ LDBG() << "Executing AreRangesEqual:";
PDLValue::Kind valueKind = read<PDLValue::Kind>();
const void *lhs = read<const void *>();
const void *rhs = read<const void *>();
@@ -1562,14 +1545,14 @@ void ByteCodeExecutor::executeAreRangesEqual() {
case PDLValue::Kind::TypeRange: {
const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(*lhsRange == *rhsRange);
break;
}
case PDLValue::Kind::ValueRange: {
const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(*lhsRange == *rhsRange);
break;
}
@@ -1579,20 +1562,19 @@ void ByteCodeExecutor::executeAreRangesEqual() {
}
void ByteCodeExecutor::executeBranch() {
- LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
+ LDBG() << "Executing Branch";
curCodeIt = &code[read<ByteCodeAddr>()];
}
void ByteCodeExecutor::executeCheckOperandCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
+ LDBG() << "Executing CheckOperandCount:";
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
bool compareAtLeast = read();
- LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
- << " * Expected: " << expectedCount << "\n"
- << " * Comparator: "
- << (compareAtLeast ? ">=" : "==") << "\n");
+ LDBG() << " * Found: " << op->getNumOperands()
+ << "\n * Expected: " << expectedCount
+ << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
if (compareAtLeast)
selectJump(op->getNumOperands() >= expectedCount);
else
@@ -1600,25 +1582,24 @@ void ByteCodeExecutor::executeCheckOperandCount() {
}
void ByteCodeExecutor::executeCheckOperationName() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
+ LDBG() << "Executing CheckOperationName:";
Operation *op = read<Operation *>();
OperationName expectedName = read<OperationName>();
- LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
- << " * Expected: \"" << expectedName << "\"\n");
+ LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
+ << expectedName << "\"";
selectJump(op->getName() == expectedName);
}
void ByteCodeExecutor::executeCheckResultCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
+ LDBG() << "Executing CheckResultCount:";
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
bool compareAtLeast = read();
- LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
- << " * Expected: " << expectedCount << "\n"
- << " * Comparator: "
- << (compareAtLeast ? ">=" : "==") << "\n");
+ LDBG() << " * Found: " << op->getNumResults()
+ << "\n * Expected: " << expectedCount
+ << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
if (compareAtLeast)
selectJump(op->getNumResults() >= expectedCount);
else
@@ -1626,36 +1607,35 @@ void ByteCodeExecutor::executeCheckResultCount() {
}
void ByteCodeExecutor::executeCheckTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+ LDBG() << "Executing AreEqual:";
TypeRange *lhs = read<TypeRange *>();
Attribute rhs = read<Attribute>();
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
}
void ByteCodeExecutor::executeContinue() {
ByteCodeField level = read();
- LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
- << " * Level: " << level << "\n");
+ LDBG() << "Executing Continue\n * Level: " << level;
++loopIndex[level];
popCodeIt();
}
void ByteCodeExecutor::executeCreateConstantTypeRange() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
+ LDBG() << "Executing CreateConstantTypeRange:";
unsigned memIndex = read();
unsigned rangeIndex = read();
ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
- LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
+ LDBG() << " * Types: " << typesAttr;
assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
rangeIndex);
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc) {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
+ LDBG() << "Executing CreateOperation:";
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
@@ -1696,45 +1676,37 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
Operation *resultOp = rewriter.create(state);
memory[memIndex] = resultOp;
- LLVM_DEBUG({
- llvm::dbgs() << " * Attributes: "
- << state.attributes.getDictionary(state.getContext())
- << "\n * Operands: ";
- llvm::interleaveComma(state.operands, llvm::dbgs());
- llvm::dbgs() << "\n * Result Types: ";
- llvm::interleaveComma(state.types, llvm::dbgs());
- llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
- });
+ LDBG() << " * Attributes: "
+ << state.attributes.getDictionary(state.getContext())
+ << "\n * Operands: " << llvm::interleaved(state.operands)
+ << "\n * Result Types: " << llvm::interleaved(state.types)
+ << "\n * Result: " << *resultOp;
}
template <typename T>
void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
+ LDBG() << "Executing CreateDynamic" << type << "Range:";
unsigned memIndex = read();
unsigned rangeIndex = read();
SmallVector<T> values;
readList(values);
- LLVM_DEBUG({
- llvm::dbgs() << "\n * " << type << "s: ";
- llvm::interleaveComma(values, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * " << type << "s: " << llvm::interleaved(values);
assignRangeToMemory(values, memIndex, rangeIndex);
}
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
+ LDBG() << "Executing EraseOp:";
Operation *op = read<Operation *>();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+ LDBG() << " * Operation: " << *op;
rewriter.eraseOp(op);
}
template <typename T, typename Range, PDLValue::Kind kind>
void ByteCodeExecutor::executeExtract() {
- LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
+ LDBG() << "Executing Extract" << kind << ":";
Range *range = read<Range *>();
unsigned index = read<uint32_t>();
unsigned memIndex = read();
@@ -1745,18 +1717,16 @@ void ByteCodeExecutor::executeExtract() {
}
T result = index < range->size() ? (*range)[index] : T();
- LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
- << " * Index: " << index << "\n"
- << " * Result: " << result << "\n");
+ LDBG() << " * " << kind << "s(" << range->size() << ")";
+ LDBG() << " * Index: " << index;
+ LDBG() << " * Result: " << result;
storeToMemory(memIndex, result);
}
-void ByteCodeExecutor::executeFinalize() {
- LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
-}
+void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
void ByteCodeExecutor::executeForEach() {
- LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
+ LDBG() << "Executing ForEach:";
const ByteCodeField *prevCodeIt = getPrevCodeIt();
unsigned rangeIndex = read();
unsigned memIndex = read();
@@ -1768,12 +1738,12 @@ void ByteCodeExecutor::executeForEach() {
ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
assert(index <= array.size() && "iterated past the end");
if (index < array.size()) {
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
+ LDBG() << " * Result: " << array[index];
value = array[index];
break;
}
- LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ LDBG() << " * Done";
index = 0;
selectJump(size_t(0));
return;
@@ -1791,49 +1761,47 @@ void ByteCodeExecutor::executeForEach() {
}
void ByteCodeExecutor::executeGetAttribute() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
+ LDBG() << "Executing GetAttribute:";
unsigned memIndex = read();
Operation *op = read<Operation *>();
StringAttr attrName = read<StringAttr>();
Attribute attr = op->getAttr(attrName);
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Attribute: " << attrName << "\n"
- << " * Result: " << attr << "\n");
+ LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
+ << "\n * Result: " << attr;
memory[memIndex] = attr.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetAttributeType() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
+ LDBG() << "Executing GetAttributeType:";
unsigned memIndex = read();
Attribute attr = read<Attribute>();
Type type;
if (auto typedAttr = dyn_cast<TypedAttr>(attr))
type = typedAttr.getType();
- LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
- << " * Result: " << type << "\n");
+ LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
memory[memIndex] = type.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetDefiningOp() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
+ LDBG() << "Executing GetDefiningOp:";
unsigned memIndex = read();
Operation *op = nullptr;
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
Value value = read<Value>();
if (value)
op = value.getDefiningOp();
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
} else {
ValueRange *values = read<ValueRange *>();
if (values && !values->empty()) {
op = values->front().getDefiningOp();
}
- LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
+ LDBG() << " * Values: " << values;
}
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
+ LDBG() << " * Result: " << op;
memory[memIndex] = op;
}
@@ -1843,9 +1811,8 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) {
Value operand =
index < op->getNumOperands() ? op->getOperand(index) : Value();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Index: " << index << "\n"
- << " * Result: " << operand << "\n");
+ LDBG() << " * Operation: " << *op << "\n * Index: " << index
+ << "\n * Result: " << operand;
memory[memIndex] = operand.getAsOpaquePointer();
}
@@ -1860,13 +1827,12 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
// Check for the sentinel index that signals that all values should be
// returned.
if (index == std::numeric_limits<uint32_t>::max()) {
- LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
+ LDBG() << " * Getting all values";
// `values` is already the full value range.
// Otherwise, check to see if this operation uses AttrSizedSegments.
} else if (op->hasTrait<AttrSizedSegmentsT>()) {
- LLVM_DEBUG(llvm::dbgs()
- << " * Extracting values from `" << attrSizedSegments << "`\n");
+ LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
@@ -1877,16 +1843,15 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
std::accumulate(segments.begin(), segments.begin() + index, 0);
values = values.slice(startIndex, *std::next(segments.begin(), index));
- LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
- << *std::next(segments.begin(), index) << "]\n");
+ LDBG() << " * Extracting range[" << startIndex << ", "
+ << *std::next(segments.begin(), index) << "]";
// Otherwise, assume this is the last operand group of the operation.
// FIXME: We currently don't support operations with
// SameVariadicOperandSize/SameVariadicResultSize here given that we don't
// have a way to detect it's presence.
} else if (values.size() >= index) {
- LLVM_DEBUG(llvm::dbgs()
- << " * Treating values as trailing variadic range\n");
+ LDBG() << " * Treating values as trailing variadic range";
values = values.drop_front(index);
// If we couldn't detect a way to compute the values, bail out.
@@ -1905,7 +1870,7 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
}
void ByteCodeExecutor::executeGetOperands() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
+ LDBG() << "Executing GetOperands:";
unsigned index = read<uint32_t>();
Operation *op = read<Operation *>();
ByteCodeField rangeIndex = read();
@@ -1914,7 +1879,7 @@ void ByteCodeExecutor::executeGetOperands() {
op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
valueRangeMemory);
if (!result)
- LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
+ LDBG() << " * Invalid operand range";
memory[read()] = result;
}
@@ -1924,14 +1889,13 @@ void ByteCodeExecutor::executeGetResult(unsigned index) {
OpResult result =
index < op->getNumResults() ? op->getResult(index) : OpResult();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Index: " << index << "\n"
- << " * Result: " << result << "\n");
+ LDBG() << " * Operation: " << *op << "\n * Index: " << index
+ << "\n * Result: " << result;
memory[memIndex] = result.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetResults() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
+ LDBG() << "Executing GetResults:";
unsigned index = read<uint32_t>();
Operation *op = read<Operation *>();
ByteCodeField rangeIndex = read();
@@ -1940,12 +1904,12 @@ void ByteCodeExecutor::executeGetResults() {
op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
valueRangeMemory);
if (!result)
- LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
+ LDBG() << " * Invalid result range";
memory[read()] = result;
}
void ByteCodeExecutor::executeGetUsers() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
+ LDBG() << "Executing GetUsers:";
unsigned memIndex = read();
unsigned rangeIndex = read();
OwningOpRange &range = opRangeMemory[rangeIndex];
@@ -1957,7 +1921,7 @@ void ByteCodeExecutor::executeGetUsers() {
Value value = read<Value>();
if (!value)
return;
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
// Extract the users of a single value.
range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
@@ -1967,11 +1931,8 @@ void ByteCodeExecutor::executeGetUsers() {
ValueRange *values = read<ValueRange *>();
if (!values)
return;
- LLVM_DEBUG({
- llvm::dbgs() << " * Values (" << values->size() << "): ";
- llvm::interleaveComma(*values, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Values (" << values->size()
+ << "): " << llvm::interleaved(*values);
// Extract all the users of a range of values.
SmallVector<Operation *> users;
@@ -1981,54 +1942,49 @@ void ByteCodeExecutor::executeGetUsers() {
llvm::copy(users, range.begin());
}
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
+ LDBG() << " * Result: " << range.size() << " operations";
}
void ByteCodeExecutor::executeGetValueType() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
+ LDBG() << "Executing GetValueType:";
unsigned memIndex = read();
Value value = read<Value>();
Type type = value ? value.getType() : Type();
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
- << " * Result: " << type << "\n");
+ LDBG() << " * Value: " << value << "\n * Result: " << type;
memory[memIndex] = type.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetValueRangeTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
+ LDBG() << "Executing GetValueRangeTypes:";
unsigned memIndex = read();
unsigned rangeIndex = read();
ValueRange *values = read<ValueRange *>();
if (!values) {
- LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
+ LDBG() << " * Values: <NULL>";
memory[memIndex] = nullptr;
return;
}
- LLVM_DEBUG({
- llvm::dbgs() << " * Values (" << values->size() << "): ";
- llvm::interleaveComma(*values, llvm::dbgs());
- llvm::dbgs() << "\n * Result: ";
- llvm::interleaveComma(values->getType(), llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Values (" << values->size()
+ << "): " << llvm::interleaved(*values)
+ << "\n * Result: " << llvm::interleaved(values->getType());
typeRangeMemory[rangeIndex] = values->getType();
memory[memIndex] = &typeRangeMemory[rangeIndex];
}
void ByteCodeExecutor::executeIsNotNull() {
- LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
+ LDBG() << "Executing IsNotNull:";
const void *value = read<const void *>();
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
selectJump(value != nullptr);
}
void ByteCodeExecutor::executeRecordMatch(
PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
- LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
+ LDBG() << "Executing RecordMatch:";
unsigned patternIndex = read();
PatternBenefit benefit = currentPatternBenefits[patternIndex];
const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
@@ -2036,7 +1992,7 @@ void ByteCodeExecutor::executeRecordMatch(
// If the benefit of the pattern is impossible, skip the processing of the
// rest of the pattern.
if (benefit.isImpossibleToMatch()) {
- LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
+ LDBG() << " * Benefit: Impossible To Match";
curCodeIt = dest;
return;
}
@@ -2052,8 +2008,8 @@ void ByteCodeExecutor::executeRecordMatch(
matchLocs.push_back(read<Operation *>()->getLoc());
Location matchLoc = rewriter.getFusedLoc(matchLocs);
- LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
- << " * Location: " << matchLoc << "\n");
+ LDBG() << " * Benefit: " << benefit.getBenefit();
+ LDBG() << " * Location: " << matchLoc;
matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
PDLByteCode::MatchResult &match = matches.back();
@@ -2083,38 +2039,34 @@ void ByteCodeExecutor::executeRecordMatch(
}
void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
+ LDBG() << "Executing ReplaceOp:";
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
readList(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Values: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Operation: " << *op
+ << "\n * Values: " << llvm::interleaved(args);
rewriter.replaceOp(op, args);
}
void ByteCodeExecutor::executeSwitchAttribute() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
+ LDBG() << "Executing SwitchAttribute:";
Attribute value = read<Attribute>();
ArrayAttr cases = read<ArrayAttr>();
handleSwitch(value, cases);
}
void ByteCodeExecutor::executeSwitchOperandCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
+ LDBG() << "Executing SwitchOperandCount:";
Operation *op = read<Operation *>();
auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+ LDBG() << " * Operation: " << *op;
handleSwitch(op->getNumOperands(), cases);
}
void ByteCodeExecutor::executeSwitchOperationName() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
+ LDBG() << "Executing SwitchOperationName:";
OperationName value = read<Operation *>()->getName();
size_t caseCount = read();
@@ -2123,13 +2075,11 @@ void ByteCodeExecutor::executeSwitchOperationName() {
// switch so that we can display all of the possible values.
LLVM_DEBUG({
const ByteCodeField *prevCodeIt = curCodeIt;
- llvm::dbgs() << " * Value: " << value << "\n"
- << " * Cases: ";
- llvm::interleaveComma(
- llvm::map_range(llvm::seq<size_t>(0, caseCount),
- [&](size_t) { return read<OperationName>(); }),
- llvm::dbgs());
- llvm::dbgs() << "\n";
+ LDBG() << " * Value: " << value << "\n * Cases: "
+ << llvm::interleaved(
+ llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
+ return read<OperationName>();
+ }));
curCodeIt = prevCodeIt;
});
@@ -2144,27 +2094,27 @@ void ByteCodeExecutor::executeSwitchOperationName() {
}
void ByteCodeExecutor::executeSwitchResultCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
+ LDBG() << "Executing SwitchResultCount:";
Operation *op = read<Operation *>();
auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+ LDBG() << " * Operation: " << *op;
handleSwitch(op->getNumResults(), cases);
}
void ByteCodeExecutor::executeSwitchType() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
+ LDBG() << "Executing SwitchType:";
Type value = read<Type>();
auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
handleSwitch(value, cases);
}
void ByteCodeExecutor::executeSwitchTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
+ LDBG() << "Executing SwitchTypes:";
TypeRange *value = read<TypeRange *>();
auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
if (!value) {
- LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
+ LDBG() << "Types: <NULL>";
return selectJump(size_t(0));
}
handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
@@ -2178,7 +2128,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
std::optional<Location> mainRewriteLoc) {
while (true) {
// Print the location of the operation being executed.
- LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
+ LDBG() << readInline<Location>();
OpCode opCode = static_cast<OpCode>(read());
switch (opCode) {
@@ -2239,7 +2189,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
break;
case Finalize:
executeFinalize();
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "";
return success();
case ForEach:
executeForEach();
@@ -2258,12 +2208,12 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case GetOperand2:
case GetOperand3: {
unsigned index = opCode - GetOperand0;
- LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
+ LDBG() << "Executing GetOperand" << index << ":";
executeGetOperand(index);
break;
}
case GetOperandN:
- LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
+ LDBG() << "Executing GetOperandN:";
executeGetOperand(read<uint32_t>());
break;
case GetOperands:
@@ -2274,12 +2224,12 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case GetResult2:
case GetResult3: {
unsigned index = opCode - GetResult0;
- LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
+ LDBG() << "Executing GetResult" << index << ":";
executeGetResult(index);
break;
}
case GetResultN:
- LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
+ LDBG() << "Executing GetResultN:";
executeGetResult(read<uint32_t>());
break;
case GetResults:
@@ -2324,7 +2274,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
executeSwitchTypes();
break;
}
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "";
}
}
@@ -2383,7 +2333,7 @@ LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
// bug in the user code (i.e. failable rewrites should not be used with
// pattern rewriters that don't support it).
if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
- LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
+ LDBG() << " and rollback is not supported - aborting";
llvm::report_fatal_error(
"Native PDL Rewrite failed, but the pattern "
"rewriter doesn't support recovery. Failable pattern rewrites should "
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index e13bcff..23ae95a 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -37,9 +37,9 @@ PatternApplicator::~PatternApplicator() = default;
#ifndef NDEBUG
/// Log a message for a pattern that is impossible to match.
static void logImpossibleToMatch(const Pattern &pattern) {
- llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
- << "' because it is impossible to match or cannot lead "
- "to legal IR (by cost model)\n";
+ LDBG() << "Ignoring pattern '" << pattern.getRootKind()
+ << "' because it is impossible to match or cannot lead "
+ "to legal IR (by cost model)";
}
/// Log IR after pattern application.
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 6eb0abc..f0c3ac4 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(SPIRV)
add_subdirectory(LLVMIR)
add_subdirectory(LLVM)
add_subdirectory(SMTLIB)
+add_subdirectory(Wasm)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 8e83e45..570f38c 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
@@ -35,7 +36,7 @@ using llvm::formatv;
/// on each element doesn't return a string.
template <typename ForwardIterator, typename UnaryFunctor,
typename NullaryFunctor>
-inline LogicalResult
+static inline LogicalResult
interleaveWithError(ForwardIterator begin, ForwardIterator end,
UnaryFunctor eachFn, NullaryFunctor betweenFn) {
if (begin == end)
@@ -52,16 +53,16 @@ interleaveWithError(ForwardIterator begin, ForwardIterator end,
}
template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
-inline LogicalResult interleaveWithError(const Container &c,
- UnaryFunctor eachFn,
- NullaryFunctor betweenFn) {
+static inline LogicalResult interleaveWithError(const Container &c,
+ UnaryFunctor eachFn,
+ NullaryFunctor betweenFn) {
return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
}
template <typename Container, typename UnaryFunctor>
-inline LogicalResult interleaveCommaWithError(const Container &c,
- raw_ostream &os,
- UnaryFunctor eachFn) {
+static inline LogicalResult interleaveCommaWithError(const Container &c,
+ raw_ostream &os,
+ UnaryFunctor eachFn) {
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
}
@@ -364,9 +365,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
if (hasDeferredEmission(user))
return false;
- // Do not inline expressions used by ops with the CExpressionInterface. If
- // this was intended, the user could have been merged into the expression op.
- return !isa<emitc::CExpressionInterface>(*user);
+ // Do not inline expressions used by other expressions or by ops with the
+ // CExpressionInterface. If this was intended, the user could have been merged
+ // into the expression op.
+ return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user);
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -749,11 +751,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (t.getType().isIndex()) {
int64_t idx = t.getInt();
Value operand = op.getOperand(idx);
- if (!emitter.hasValueInScope(operand))
- return op.emitOpError("operand ")
- << idx << "'s value not defined in scope";
- os << emitter.getOrCreateName(operand);
- return success();
+ return emitter.emitOperand(operand);
}
}
if (failed(emitter.emitAttribute(op.getLoc(), attr)))
@@ -782,9 +780,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << applyOp.getApplicableOperator();
- os << emitter.getOrCreateName(applyOp.getOperand());
-
- return success();
+ return emitter.emitOperand(applyOp.getOperand());
}
static LogicalResult printOperation(CppEmitter &emitter,
@@ -1447,7 +1443,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
}
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
if (auto iType = dyn_cast<IntegerType>(
- cast<TensorType>(dense.getType()).getElementType())) {
+ cast<ShapedType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -1456,7 +1452,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return success();
}
if (auto iType = dyn_cast<IndexType>(
- cast<TensorType>(dense.getType()).getElementType())) {
+ cast<ShapedType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os,
[&](const APInt &val) { printInt(val, false); });
@@ -1538,6 +1534,20 @@ LogicalResult CppEmitter::emitOperand(Value value) {
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
+ // If this operand is a block argument of an expression, emit instead the
+ // matching expression parameter.
+ Operation *argOp = arg.getParentBlock()->getParentOp();
+ if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) {
+ // This scenario is only expected when one of the operations within the
+ // expression being emitted references one of the expression's block
+ // arguments.
+ assert(expressionOp == emittedExpression &&
+ "Expected expression being emitted");
+ value = expressionOp->getOperand(arg.getArgNumber());
+ }
+ }
+
os << getOrCreateName(value);
return success();
}
@@ -1793,7 +1803,7 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
case 16: {
if (llvm::isa<Float16Type>(type))
return (os << "_Float16"), success();
- else if (llvm::isa<BFloat16Type>(type))
+ if (llvm::isa<BFloat16Type>(type))
return (os << "__bf16"), success();
else
return emitError(loc, "cannot emit float type ") << type;
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index f6e44c6..9a0e4d4 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -210,3 +210,27 @@ if(MLIR_ENABLE_ROCM_CONVERSIONS)
)
endif()
+if ("SPIRV" IN_LIST LLVM_TARGETS_TO_BUILD)
+ set(SPIRV_LIBS
+ SPIRVCodeGen
+ SPIRVDesc
+ SPIRVInfo
+ )
+endif()
+
+add_mlir_dialect_library(MLIRXeVMTarget
+ XeVM/Target.cpp
+
+ OBJECT
+
+ LINK_COMPONENTS
+ ${SPIRV_LIBS}
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRExecutionEngineUtils
+ MLIRSupport
+ MLIRGPUDialect
+ MLIRTargetLLVM
+ MLIRXeVMToLLVMIRTranslation
+)
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 55c8a64..8760ea8 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -24,9 +24,11 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/Support/InterleavedRange.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Config/Targets.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/FormatVariadic.h"
@@ -265,6 +267,8 @@ NVPTXSerializer::NVPTXSerializer(Operation &module, NVVMTargetAttr target,
std::optional<NVPTXSerializer::TmpFile>
NVPTXSerializer::createTemp(StringRef name, StringRef suffix) {
llvm::SmallString<128> filename;
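+  // Truncate very long module names so the temporary file name stays short.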
+ if (name.size() > 80)
+ name = name.substr(0, 80);
std::error_code ec =
llvm::sys::fs::createTemporaryFile(name, suffix, filename);
if (ec) {
@@ -452,17 +456,11 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
// Dump tool invocation commands.
#define DEBUG_TYPE "serialize-to-binary"
- LLVM_DEBUG({
- llvm::dbgs() << "Tool invocation for module: "
- << getOperation().getNameAttr() << "\n";
- llvm::dbgs() << "ptxas executable:" << ptxasCompiler.value() << "\n";
- llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
- llvm::dbgs() << "\n";
- if (createFatbin) {
- llvm::interleave(fatbinArgs, llvm::dbgs(), " ");
- llvm::dbgs() << "\n";
- }
- });
+ LDBG() << "Tool invocation for module: " << getOperation().getNameAttr()
+ << "\nptxas executable:" << ptxasCompiler.value()
+ << "\nptxas args: " << llvm::interleaved(ptxasArgs, " ");
+ if (createFatbin)
+ LDBG() << "fatbin args: " << llvm::interleaved(fatbinArgs, " ");
#undef DEBUG_TYPE
// Helper function for printing tool error logs.
@@ -507,7 +505,7 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
llvm::MemoryBuffer::getFile(logFile->first);
if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
- llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
+ LDBG() << "Output:\n" << (*logBuffer)->getBuffer();
llvm::dbgs().flush();
}
});
@@ -529,7 +527,7 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
llvm::MemoryBuffer::getFile(logFile->first);
if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
- llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
+ LDBG() << "Output:\n" << (*logBuffer)->getBuffer();
llvm::dbgs().flush();
}
});
@@ -629,12 +627,11 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
SmallVector<char> log(logSize + 1, 0);
RETURN_ON_NVPTXCOMPILER_ERROR(
nvPTXCompilerGetInfoLog(compiler, log.data()));
- llvm::dbgs() << "NVPTX compiler invocation for module: "
- << getOperation().getNameAttr() << "\n";
- llvm::dbgs() << "Arguments: ";
- llvm::interleave(cmdOpts.second, llvm::dbgs(), " ");
- llvm::dbgs() << "\nOutput\n" << log.data() << "\n";
- llvm::dbgs().flush();
+ LDBG() << "NVPTX compiler invocation for module: "
+ << getOperation().getNameAttr()
+ << "\nArguments: " << llvm::interleaved(cmdOpts.second, " ")
+ << "\nOutput\n"
+ << log.data();
}
});
#undef DEBUG_TYPE
@@ -678,10 +675,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
// Return LLVM IR if the compilation target is `offload`.
#define DEBUG_TYPE "serialize-to-llvm"
LLVM_DEBUG({
- llvm::dbgs() << "LLVM IR for module: " << getOperation().getNameAttr()
- << "\n";
- llvm::dbgs() << llvmModule << "\n";
- llvm::dbgs().flush();
+ LDBG() << "LLVM IR for module: " << getOperation().getNameAttr();
+ LDBG() << llvmModule;
});
#undef DEBUG_TYPE
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload)
@@ -716,11 +711,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
isaCallback(serializedISA.value());
#define DEBUG_TYPE "serialize-to-isa"
- LLVM_DEBUG({
- llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
- llvm::dbgs() << *serializedISA << "\n";
- llvm::dbgs().flush();
- });
+ LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
+ << *serializedISA;
#undef DEBUG_TYPE
// Return PTX if the compilation target is `assembly`.
diff --git a/mlir/lib/Target/LLVM/XeVM/Target.cpp b/mlir/lib/Target/LLVM/XeVM/Target.cpp
new file mode 100644
index 0000000..1e6784a2
--- /dev/null
+++ b/mlir/lib/Target/LLVM/XeVM/Target.cpp
@@ -0,0 +1,418 @@
+//===- Target.cpp - MLIR LLVM XeVM target compilation -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines XeVM target-related functions including registration
+// calls for the `#xevm.target` compilation attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/XeVM/Target.h"
+
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/Target/LLVM/XeVM/Utils.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/Target/TargetMachine.h"
+
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/Config/Targets.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+#include <cstdlib>
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+namespace {
+// XeVM implementation of the gpu::TargetAttrInterface.
+class XeVMTargetAttrImpl
+ : public gpu::TargetAttrInterface::FallbackModel<XeVMTargetAttrImpl> {
+public:
+ std::optional<SmallVector<char, 0>>
+ serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const;
+
+ Attribute createObject(Attribute attribute, Operation *module,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const;
+};
+} // namespace
+
+void mlir::xevm::registerXeVMTargetInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
+ XeVMTargetAttr::attachInterface<XeVMTargetAttrImpl>(*ctx);
+ });
+}
+
+void mlir::xevm::registerXeVMTargetInterfaceExternalModels(
+ MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMTargetInterfaceExternalModels(registry);
+ context.appendDialectRegistry(registry);
+}
+
+SerializeGPUModuleBase::SerializeGPUModuleBase(
+ Operation &module, XeVMTargetAttr xeTarget,
+ const gpu::TargetOptions &targetOptions)
+ : ModuleToObject(module, xeTarget.getTriple(), "", {}, xeTarget.getO()),
+ xeTarget(xeTarget), librariesToLink(targetOptions.getLibrariesToLink()),
+ targetOptions(targetOptions) {
+ if (xeTarget.getLinkFiles())
+ librariesToLink.append(xeTarget.getLinkFiles().begin(),
+ xeTarget.getLinkFiles().end());
+}
+
+XeVMTargetAttr SerializeGPUModuleBase::getTarget() const { return xeTarget; }
+
+std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
+SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) {
+ if (librariesToLink.empty())
+ return SmallVector<std::unique_ptr<llvm::Module>>();
+ SmallVector<std::unique_ptr<llvm::Module>> bcFiles;
+ if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink,
+ bcFiles)))
+ return std::nullopt;
+ return std::move(bcFiles);
+}
+
+gpu::GPUModuleOp SerializeGPUModuleBase::getGPUModuleOp() {
+ return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
+}
+
+// There is one way to finalize IL to native code: IGC.
+// There are two ways to access IGC: AOT (ocloc) and JIT (the L0 runtime).
+// - The L0 runtime consumes IL and is external to the MLIR codebase (rt
+//   wrappers).
+// - The `ocloc` tool can be invoked from within MLIR.
+std::optional<SmallVector<char, 0>>
+SerializeGPUModuleBase::compileToBinary(const std::string &asmStr,
+ StringRef inputFormat) {
+ using TmpFile = std::pair<llvm::SmallString<128>, llvm::FileRemover>;
+ // Find the `ocloc` tool.
+ std::optional<std::string> oclocCompiler = findTool("ocloc");
+ if (!oclocCompiler)
+ return std::nullopt;
+ Location loc = getGPUModuleOp().getLoc();
+ std::string basename = llvm::formatv(
+ "mlir-{0}-{1}-{2}", getGPUModuleOp().getNameAttr().getValue(),
+ getTarget().getTriple(), getTarget().getChip());
+
+ auto createTemp = [&](StringRef name,
+ StringRef suffix) -> std::optional<TmpFile> {
+ llvm::SmallString<128> filePath;
+ if (auto ec = llvm::sys::fs::createTemporaryFile(name, suffix, filePath)) {
+ getGPUModuleOp().emitError()
+ << "Couldn't create the temp file: `" << filePath
+ << "`, error message: " << ec.message();
+ return std::nullopt;
+ }
+ return TmpFile(filePath, llvm::FileRemover(filePath.c_str()));
+ };
+ // Create temp file
+ std::optional<TmpFile> asmFile = createTemp(basename, "asm");
+ std::optional<TmpFile> binFile = createTemp(basename, "");
+ std::optional<TmpFile> logFile = createTemp(basename, "log");
+ if (!logFile || !asmFile || !binFile)
+ return std::nullopt;
+ // Dump the assembly to a temp file
+ std::error_code ec;
+ {
+ llvm::raw_fd_ostream asmStream(asmFile->first, ec);
+ if (ec) {
+ emitError(loc) << "Couldn't open the file: `" << asmFile->first
+ << "`, error message: " << ec.message();
+ return std::nullopt;
+ }
+ asmStream << asmStr;
+ if (asmStream.has_error()) {
+ emitError(loc) << "An error occurred while writing the assembly to: `"
+ << asmFile->first << "`.";
+ return std::nullopt;
+ }
+ asmStream.flush();
+ }
+ // Set cmd options
+ std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
+ targetOptions.tokenizeCmdOptions();
+ // Example: --gpu-module-to-binary="opts='opt1 opt2'"
+ const std::string cmdOptsStr = "\"" + llvm::join(cmdOpts.second, " ") + "\"";
+ SmallVector<StringRef, 12> oclocArgs(
+ {"ocloc", "compile", "-file", asmFile->first, inputFormat, "-device",
+ getTarget().getChip(), "-output", binFile->first, "-output_no_suffix",
+ "-options", cmdOptsStr});
+
+// Dump tool invocation commands.
+#define DEBUG_TYPE "serialize-to-binary"
+ LLVM_DEBUG({
+ llvm::dbgs() << "Tool invocation for module: "
+ << getGPUModuleOp().getNameAttr() << "\n";
+ llvm::interleave(oclocArgs, llvm::dbgs(), " ");
+ llvm::dbgs() << "\n";
+ });
+#undef DEBUG_TYPE
+ // Helper function for printing tool error logs.
+ std::string message;
+ auto emitLogError =
+ [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
+ if (message.empty()) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
+ llvm::MemoryBuffer::getFile(logFile->first);
+ if (toolStderr)
+ emitError(loc) << toolName << " invocation failed. Log:\n"
+ << toolStderr->get()->getBuffer();
+ else
+ emitError(loc) << toolName << " invocation failed.";
+ return std::nullopt;
+ }
+ emitError(loc) << toolName
+ << " invocation failed, error message: " << message;
+ return std::nullopt;
+ };
+ std::optional<StringRef> redirects[] = {
+ std::nullopt,
+ logFile->first,
+ logFile->first,
+ };
+ // Invoke ocloc.
+ if (llvm::sys::ExecuteAndWait(oclocCompiler.value(), oclocArgs, std::nullopt,
+ redirects, 0, 0, &message))
+ return emitLogError("`ocloc`");
+ binFile->first.append(".bin");
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
+ llvm::MemoryBuffer::getFile(binFile->first);
+ if (!binaryBuffer) {
+ emitError(loc) << "Couldn't open the file: `" << binFile->first
+ << "`, error message: " << binaryBuffer.getError().message();
+ return std::nullopt;
+ }
+ StringRef bin = (*binaryBuffer)->getBuffer();
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+}
+
+std::optional<std::string> SerializeGPUModuleBase::findTool(StringRef tool) {
+ // 1. Check the toolkit path given in the command line.
+ StringRef pathRef = targetOptions.getToolkitPath();
+ SmallVector<char, 256> path;
+ if (!pathRef.empty()) {
+ path.insert(path.begin(), pathRef.begin(), pathRef.end());
+ llvm::sys::path::append(path, "bin", tool);
+ if (llvm::sys::fs::can_execute(path))
+ return StringRef(path.data(), path.size()).str();
+ }
+ // 2. Check PATH.
+ if (std::optional<std::string> toolPath =
+ llvm::sys::Process::FindInEnvPath("PATH", tool))
+ return *toolPath;
+
+ getGPUModuleOp().emitError()
+ << "Couldn't find the `" << tool
+ << "` binary. Please specify the toolkit "
+ "path via GpuModuleToBinaryPass or add the compiler to $PATH`.";
+ return std::nullopt;
+}
+
+namespace {
+class SPIRVSerializer : public SerializeGPUModuleBase {
+public:
+ SPIRVSerializer(Operation &module, XeVMTargetAttr xeTarget,
+ const gpu::TargetOptions &targetOptions)
+ : SerializeGPUModuleBase(module, xeTarget, targetOptions) {}
+
+ static void init();
+
+ /// Serializes the LLVM module to an object format, depending on the
+ /// compilation target selected in target options.
+ std::optional<SmallVector<char, 0>>
+ moduleToObject(llvm::Module &llvmModule) override;
+
+private:
+ /// Translates the LLVM module to SPIR-V binary using LLVM's
+ /// SPIR-V target.
+ std::optional<std::string>
+ translateToSPIRVBinary(llvm::Module &llvmModule,
+ llvm::TargetMachine &targetMachine);
+};
+} // namespace
+
+void SPIRVSerializer::init() {
+ static llvm::once_flag initializeBackendOnce;
+ llvm::call_once(initializeBackendOnce, []() {
+#if LLVM_HAS_SPIRV_TARGET
+ LLVMInitializeSPIRVTarget();
+ LLVMInitializeSPIRVTargetInfo();
+ LLVMInitializeSPIRVTargetMC();
+ LLVMInitializeSPIRVAsmPrinter();
+#endif
+ });
+}
+
+std::optional<SmallVector<char, 0>>
+SPIRVSerializer::moduleToObject(llvm::Module &llvmModule) {
+#define DEBUG_TYPE "serialize-to-llvm"
+ LLVM_DEBUG({
+ llvm::dbgs() << "LLVM IR for module: " << getGPUModuleOp().getNameAttr()
+ << "\n";
+ llvm::dbgs() << llvmModule << "\n";
+ llvm::dbgs().flush();
+ });
+#undef DEBUG_TYPE
+
+ // Return LLVM IR if the compilation target is `offload`.
+ if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload)
+ return SerializeGPUModuleBase::moduleToObject(llvmModule);
+
+#if !LLVM_HAS_SPIRV_TARGET
+ getGPUModuleOp()->emitError("The `SPIRV` target was not built. Please enable "
+ "it when building LLVM.");
+ return std::nullopt;
+#endif // LLVM_HAS_SPIRV_TARGET
+
+ std::optional<llvm::TargetMachine *> targetMachine =
+ getOrCreateTargetMachine();
+ if (!targetMachine) {
+ getGPUModuleOp().emitError() << "Target Machine unavailable for triple "
+ << triple << ", can't optimize with LLVM\n";
+ return std::nullopt;
+ }
+
+ // Return SPIRV if the compilation target is `assembly`.
+ if (targetOptions.getCompilationTarget() ==
+ gpu::CompilationTarget::Assembly) {
+ std::optional<std::string> serializedISA =
+ translateToISA(llvmModule, **targetMachine);
+ if (!serializedISA) {
+ getGPUModuleOp().emitError() << "Failed translating the module to ISA."
+ << triple << ", can't compile with LLVM\n";
+ return std::nullopt;
+ }
+
+#define DEBUG_TYPE "serialize-to-isa"
+ LLVM_DEBUG({
+ llvm::dbgs() << "SPIR-V for module: " << getGPUModuleOp().getNameAttr()
+ << "\n";
+ llvm::dbgs() << *serializedISA << "\n";
+ llvm::dbgs().flush();
+ });
+#undef DEBUG_TYPE
+
+ // Make sure to include the null terminator.
+ StringRef bin(serializedISA->c_str(), serializedISA->size() + 1);
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+ }
+
+  // The Level Zero runtime is set up to accept SPIR-V binaries.
+  // translateToSPIRVBinary translates the LLVM module to a SPIR-V binary
+  // using LLVM's SPIR-V target.
+  // compileToBinary can be used in the future if the Level Zero runtime
+  // implementation switches to a native XeVM binary format.
+ std::optional<std::string> serializedSPIRVBinary =
+ translateToSPIRVBinary(llvmModule, **targetMachine);
+ if (!serializedSPIRVBinary) {
+ getGPUModuleOp().emitError() << "Failed translating the module to Binary.";
+ return std::nullopt;
+ }
+ if (serializedSPIRVBinary->size() % 4) {
+ getGPUModuleOp().emitError() << "SPIRV code size must be a multiple of 4.";
+ return std::nullopt;
+ }
+ StringRef bin(serializedSPIRVBinary->c_str(), serializedSPIRVBinary->size());
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+}
+
+std::optional<std::string>
+SPIRVSerializer::translateToSPIRVBinary(llvm::Module &llvmModule,
+ llvm::TargetMachine &targetMachine) {
+ std::string targetISA;
+ llvm::raw_string_ostream stream(targetISA);
+
+ { // Drop pstream after this to prevent the ISA from being stuck buffering
+ llvm::buffer_ostream pstream(stream);
+ llvm::legacy::PassManager codegenPasses;
+ if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
+ llvm::CodeGenFileType::ObjectFile))
+ return std::nullopt;
+
+ codegenPasses.run(llvmModule);
+ }
+ return targetISA;
+}
+
+std::optional<SmallVector<char, 0>>
+XeVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const {
+ if (!module)
+ return std::nullopt;
+ auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
+ if (!gpuMod) {
+ module->emitError("expected to be a gpu.module op");
+ return std::nullopt;
+ }
+ auto xeTarget = cast<XeVMTargetAttr>(attribute);
+ if (xeTarget.getTriple().starts_with("spirv")) {
+ gpuMod.walk([&](LLVM::LLVMFuncOp funcOp) {
+ if (funcOp->hasAttr(gpu::GPUDialect::getKernelFuncAttrName())) {
+ funcOp.setIntelReqdSubGroupSize(16);
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ SPIRVSerializer serializer(*module, cast<XeVMTargetAttr>(attribute),
+ options);
+ serializer.init();
+
+#if !LLVM_HAS_SPIRV_TARGET
+ module->emitError("Cannot run `TargetRegistry::lookupTarget()` for SPIRV "
+ "without having the target built.");
+#endif
+
+ return serializer.run();
+ }
+ module->emitError("Unsupported XeVM target triple: ") << xeTarget.getTriple();
+ return std::nullopt;
+}
+
+Attribute
+XeVMTargetAttrImpl::createObject(Attribute attribute, Operation *module,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const {
+ Builder builder(attribute.getContext());
+ gpu::CompilationTarget format = options.getCompilationTarget();
+ auto xeTarget = cast<XeVMTargetAttr>(attribute);
+ SmallVector<NamedAttribute, 2> properties;
+ if (format == gpu::CompilationTarget::Assembly)
+ properties.push_back(
+ builder.getNamedAttr("O", builder.getI32IntegerAttr(xeTarget.getO())));
+
+ DictionaryAttr objectProps;
+ if (!properties.empty())
+ objectProps = builder.getDictionaryAttr(properties);
+
+ return builder.getAttr<gpu::ObjectAttr>(
+ attribute, format,
+ builder.getStringAttr(StringRef(object.data(), object.size())),
+ objectProps, /*kernels=*/nullptr);
+}
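compileToBinary above follows the usual external-tool pattern for GPU targets: write the assembly to a temporary file, run the vendor tool with stdout/stderr redirected to a log, and read the produced artifacts back. A minimal, self-contained sketch of that pattern (the tool name and arguments are placeholders, not the exact ocloc invocation):

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Program.h"
#include <optional>
#include <string>

// Runs `tool` on `input`, redirecting stdout/stderr to a temporary log file,
// and returns the log contents. Placeholder logic only; the real serializer
// also collects the produced binary and reports errors on the GPU module.
static std::optional<std::string> runTool(llvm::StringRef tool,
                                          llvm::StringRef input) {
  llvm::SmallString<128> logPath;
  if (llvm::sys::fs::createTemporaryFile("tool", "log", logPath))
    return std::nullopt;
  llvm::FileRemover logRemover(logPath.c_str());

  std::optional<llvm::StringRef> redirects[] = {std::nullopt, logPath.str(),
                                                logPath.str()};
  std::string errMsg;
  if (llvm::sys::ExecuteAndWait(tool, {tool, input}, std::nullopt, redirects,
                                /*SecondsToWait=*/0, /*MemoryLimit=*/0,
                                &errMsg))
    return std::nullopt;

  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer =
      llvm::MemoryBuffer::getFile(logPath);
  if (!buffer)
    return std::nullopt;
  return (*buffer)->getBuffer().str();
}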
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 9ea5c683..a73a78d 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Dialect)
+add_subdirectory(Transforms)
set(LLVM_OPTIONAL_SOURCES
ConvertFromLLVMIR.cpp
@@ -58,6 +59,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIROpenACCToLLVMIRTranslation
MLIROpenMPToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
+ MLIRPtrToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
MLIRXeVMToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp b/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp
index fbad5c2..8bd07cd 100644
--- a/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp
@@ -6,13 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#include "DataLayoutImporter.h"
+#include "mlir/Target/LLVMIR/DataLayoutImporter.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Target/LLVMIR/Import.h"
+
#include "llvm/IR/DataLayout.h"
using namespace mlir;
@@ -274,101 +275,88 @@ DataLayoutImporter::tryToEmplaceLegalIntWidthsEntry(StringRef token) {
return success();
}
-void DataLayoutImporter::translateDataLayout(
- const llvm::DataLayout &llvmDataLayout) {
- dataLayout = {};
-
- // Transform the data layout to its string representation and append the
- // default data layout string specified in the language reference
- // (https://llvm.org/docs/LangRef.html#data-layout). The translation then
- // parses the string and ignores the default value if a specific kind occurs
- // in both strings. Additionally, the following default values exist:
- // - non-default address space pointer specifications default to the default
- // address space pointer specification
- // - the alloca address space defaults to the default address space.
- layoutStr = llvmDataLayout.getStringRepresentation();
- if (!layoutStr.empty())
- layoutStr += "-";
- layoutStr += kDefaultDataLayout;
- StringRef layout(layoutStr);
+DataLayoutSpecInterface DataLayoutImporter::dataLayoutSpecFromDataLayoutStr() {
+ if (!dataLayoutStr.empty())
+ dataLayoutStr += "-";
+ dataLayoutStr += kDefaultDataLayout;
// Split the data layout string into tokens separated by a dash.
SmallVector<StringRef> tokens;
- layout.split(tokens, '-');
+ StringRef(dataLayoutStr).split(tokens, '-');
for (StringRef token : tokens) {
lastToken = token;
FailureOr<StringRef> prefix = tryToParseAlphaPrefix(token);
if (failed(prefix))
- return;
+ return {};
// Parse the endianness.
if (*prefix == "e") {
if (failed(tryToEmplaceEndiannessEntry(
DLTIDialect::kDataLayoutEndiannessLittle, token)))
- return;
+ return {};
continue;
}
if (*prefix == "E") {
if (failed(tryToEmplaceEndiannessEntry(
DLTIDialect::kDataLayoutEndiannessBig, token)))
- return;
+ return {};
continue;
}
// Parse the program address space.
if (*prefix == "P") {
if (failed(tryToEmplaceAddrSpaceEntry(
token, DLTIDialect::kDataLayoutProgramMemorySpaceKey)))
- return;
+ return {};
continue;
}
// Parse the mangling mode.
if (*prefix == "m") {
if (failed(tryToEmplaceManglingModeEntry(
token, DLTIDialect::kDataLayoutManglingModeKey)))
- return;
+ return {};
continue;
}
// Parse the global address space.
if (*prefix == "G") {
if (failed(tryToEmplaceAddrSpaceEntry(
token, DLTIDialect::kDataLayoutGlobalMemorySpaceKey)))
- return;
+ return {};
continue;
}
// Parse the alloca address space.
if (*prefix == "A") {
if (failed(tryToEmplaceAddrSpaceEntry(
token, DLTIDialect::kDataLayoutAllocaMemorySpaceKey)))
- return;
+ return {};
continue;
}
// Parse the stack alignment.
if (*prefix == "S") {
if (failed(tryToEmplaceStackAlignmentEntry(token)))
- return;
+ return {};
continue;
}
// Parse integer alignment specifications.
if (*prefix == "i") {
FailureOr<uint64_t> width = tryToParseInt(token);
if (failed(width))
- return;
+ return {};
Type type = IntegerType::get(context, *width);
if (failed(tryToEmplaceAlignmentEntry(type, token)))
- return;
+ return {};
continue;
}
// Parse float alignment specifications.
if (*prefix == "f") {
FailureOr<uint64_t> width = tryToParseInt(token);
if (failed(width))
- return;
+ return {};
Type type = getFloatType(context, *width);
if (failed(tryToEmplaceAlignmentEntry(type, token)))
- return;
+ return {};
continue;
}
// Parse pointer alignment specifications.
@@ -376,17 +364,17 @@ void DataLayoutImporter::translateDataLayout(
FailureOr<uint64_t> space =
token.starts_with(":") ? 0 : tryToParseInt(token);
if (failed(space))
- return;
+ return {};
auto type = LLVMPointerType::get(context, *space);
if (failed(tryToEmplacePointerAlignmentEntry(type, token)))
- return;
+ return {};
continue;
}
// Parse native integer widths specifications.
if (*prefix == "n") {
if (failed(tryToEmplaceLegalIntWidthsEntry(token)))
- return;
+ return {};
continue;
}
// Parse function pointer alignment specifications.
@@ -394,7 +382,7 @@ void DataLayoutImporter::translateDataLayout(
if (prefix->starts_with("F")) {
StringRef nextPrefix = prefix->drop_front(1);
if (failed(tryToEmplaceFunctionPointerAlignmentEntry(nextPrefix, token)))
- return;
+ return {};
continue;
}
@@ -409,11 +397,12 @@ void DataLayoutImporter::translateDataLayout(
entries.push_back(it.second);
for (const auto &it : keyEntries)
entries.push_back(it.second);
- dataLayout = DataLayoutSpecAttr::get(context, entries);
+ return DataLayoutSpecAttr::get(context, entries);
}
DataLayoutSpecInterface
mlir::translateDataLayout(const llvm::DataLayout &dataLayout,
MLIRContext *context) {
- return DataLayoutImporter(context, dataLayout).getDataLayout();
+ return DataLayoutImporter(context, dataLayout.getStringRepresentation())
+ .getDataLayoutSpec();
}
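The importer now parses the raw data layout string directly. Its main loop reduces to splitting on '-' and dispatching on the alphabetic prefix of each token; a standalone sketch of just that tokenization step (illustrative only, not the full set of specifications handled above):

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cctype>
#include <cstdio>

// Splits an LLVM data layout string into dash-separated tokens and strips the
// alphabetic prefix identifying each specification, mirroring the dispatch in
// dataLayoutSpecFromDataLayoutStr().
static void dumpDataLayoutTokens(llvm::StringRef layout) {
  llvm::SmallVector<llvm::StringRef> tokens;
  layout.split(tokens, '-');
  for (llvm::StringRef token : tokens) {
    llvm::StringRef prefix = token.take_while(
        [](char c) { return std::isalpha(static_cast<unsigned char>(c)); });
    llvm::StringRef params = token.drop_front(prefix.size());
    std::printf("spec '%s' params '%s'\n", prefix.str().c_str(),
                params.str().c_str());
  }
}

// Usage: dumpDataLayoutTokens("e-m:e-i64:64-f80:128-n8:16:32:64-S128");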
diff --git a/mlir/lib/Target/LLVMIR/DataLayoutImporter.h b/mlir/lib/Target/LLVMIR/DataLayoutImporter.h
deleted file mode 100644
index 88ceaf1..0000000
--- a/mlir/lib/Target/LLVMIR/DataLayoutImporter.h
+++ /dev/null
@@ -1,132 +0,0 @@
-//===- DataLayoutImporter.h - LLVM to MLIR data layout conversion -*- C++ -*-=//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the translation between the LLVMIR data layout and the
-// corresponding MLIR representation.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_
-#define MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_
-
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "llvm/ADT/MapVector.h"
-
-namespace llvm {
-class StringRef;
-class DataLayout;
-} // namespace llvm
-
-namespace mlir {
-class FloatType;
-class MLIRContext;
-class Operation;
-
-namespace LLVM {
-class LLVMFuncOp;
-
-namespace detail {
-
-/// Returns a supported MLIR floating point type of the given bit width or
-/// null if the bit width is not supported.
-FloatType getFloatType(MLIRContext *context, unsigned width);
-
-/// Helper class that translates an LLVM data layout to an MLIR data layout
-/// specification. Only integer, float, pointer, alloca memory space, stack
-/// alignment, and endianness entries are translated. The class also returns all
-/// entries from the default data layout specification found in the language
-/// reference (https://llvm.org/docs/LangRef.html#data-layout) if they are not
-/// overwritten by the provided data layout.
-class DataLayoutImporter {
-public:
- DataLayoutImporter(MLIRContext *context,
- const llvm::DataLayout &llvmDataLayout)
- : context(context) {
- translateDataLayout(llvmDataLayout);
- }
-
- /// Returns the MLIR data layout specification translated from the LLVM
- /// data layout.
- DataLayoutSpecInterface getDataLayout() const { return dataLayout; }
-
- /// Returns the last data layout token that has been processed before
- /// the data layout translation failed.
- StringRef getLastToken() const { return lastToken; }
-
- /// Returns the data layout tokens that have not been handled during the
- /// data layout translation.
- ArrayRef<StringRef> getUnhandledTokens() const { return unhandledTokens; }
-
-private:
- /// Translates the LLVM `dataLayout` to an MLIR data layout specification.
- void translateDataLayout(const llvm::DataLayout &llvmDataLayout);
-
- /// Tries to parse the letter only prefix that identifies the specification
- /// and removes the consumed characters from the beginning of the string.
- FailureOr<StringRef> tryToParseAlphaPrefix(StringRef &token) const;
-
- /// Tries to parse an integer parameter and removes the integer from the
- /// beginning of the string.
- FailureOr<uint64_t> tryToParseInt(StringRef &token) const;
-
- /// Tries to parse an integer parameter array.
- FailureOr<SmallVector<uint64_t>> tryToParseIntList(StringRef token) const;
-
- /// Tries to parse the parameters of a type alignment entry.
- FailureOr<DenseIntElementsAttr> tryToParseAlignment(StringRef token) const;
-
- /// Tries to parse the parameters of a pointer alignment entry.
- FailureOr<DenseIntElementsAttr>
- tryToParsePointerAlignment(StringRef token) const;
-
- /// Adds a type alignment entry if there is none yet.
- LogicalResult tryToEmplaceAlignmentEntry(Type type, StringRef token);
-
- /// Adds a pointer alignment entry if there is none yet.
- LogicalResult tryToEmplacePointerAlignmentEntry(LLVMPointerType type,
- StringRef token);
-
- /// Adds an endianness entry if there is none yet.
- LogicalResult tryToEmplaceEndiannessEntry(StringRef endianness,
- StringRef token);
-
- /// Adds an alloca address space entry if there is none yet.
- LogicalResult tryToEmplaceAddrSpaceEntry(StringRef token,
- llvm::StringLiteral spaceKey);
-
- /// Adds an mangling mode entry if there is none yet.
- LogicalResult tryToEmplaceManglingModeEntry(StringRef token,
- llvm::StringLiteral manglingKey);
-
- /// Adds a stack alignment entry if there is none yet.
- LogicalResult tryToEmplaceStackAlignmentEntry(StringRef token);
-
- /// Adds a function pointer alignment entry if there is none yet.
- LogicalResult
- tryToEmplaceFunctionPointerAlignmentEntry(StringRef fnPtrAlignEntry,
- StringRef token);
-
- /// Adds legal int widths entry if there is none yet.
- LogicalResult tryToEmplaceLegalIntWidthsEntry(StringRef token);
-
- std::string layoutStr = {};
- StringRef lastToken = {};
- SmallVector<StringRef> unhandledTokens;
- llvm::MapVector<StringAttr, DataLayoutEntryInterface> keyEntries;
- llvm::MapVector<TypeAttr, DataLayoutEntryInterface> typeEntries;
- MLIRContext *context;
- DataLayoutSpecInterface dataLayout;
-};
-
-} // namespace detail
-} // namespace LLVM
-} // namespace mlir
-
-#endif // MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index 86c731a..a102c43 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(NVVM)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
+add_subdirectory(Ptr)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 0f675a0..fd8463a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -18,6 +18,7 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
@@ -358,6 +359,17 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
}
}
+static llvm::DILocalScope *
+getLocalScopeFromLoc(llvm::IRBuilderBase &builder, Location loc,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ if (auto scopeLoc =
+ loc->findInstanceOf<FusedLocWith<LLVM::DILocalScopeAttr>>())
+ if (auto *localScope = llvm::dyn_cast<llvm::DILocalScope>(
+ moduleTranslation.translateDebugInfo(scopeLoc.getMetadata())))
+ return localScope;
+ return builder.GetInsertBlock()->getParent()->getSubprogram();
+}
+
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 90462d1..7f69af14 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
llvm_unreachable("unsupported vote kind");
}
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
- int32_t num) {
- if (layout == NVVM::MMALayout::row) {
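+/// Return the intrinsic ID associated with ldmatrix for the given layout,
+/// number of tiles, shape, and element type.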
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShapeAttr shape,
+ NVVM::LdStMatrixEltType eltType) {
+ if (shape.getM() == 8 && shape.getN() == 8) {
switch (num) {
case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
}
-
- } else {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ } else if (shape.getM() == 8 && shape.getN() == 16) {
+ if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+ }
+ }
+ } else if (shape.getM() == 16 && shape.getN() == 16) {
+ if (eltType == NVVM::LdStMatrixEltType::B8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+ }
}
}
+ llvm_unreachable("unknown ldmatrix kind");
}
/// Return the intrinsic ID associated with stmatrix for the given paramters.
@@ -418,7 +468,11 @@ public:
} else if (attribute.getName() ==
NVVM::NVVMDialect::getKernelFuncAttrName()) {
llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
+ } else if (attribute.getName() ==
+ NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
+ llvmFunc->addFnAttr("nvvm.blocksareclusters");
}
+
return success();
}
@@ -429,51 +483,10 @@ public:
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
- llvm::NamedMDNode *nvvmAnnotations =
- moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
- llvm::MDNode *gridConstantMetaData = nullptr;
-
- // Check if a 'grid_constant' metadata node exists for the given function
- for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
- if (opnd->getNumOperands() == 3 &&
- opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
- opnd->getOperand(1) ==
- llvm::MDString::get(llvmContext, "grid_constant")) {
- gridConstantMetaData = opnd;
- break;
- }
- }
-
- // 'grid_constant' is a function-level meta data node with a list of
- // integers, where each integer n denotes that the nth parameter has the
- // grid_constant annotation (numbering from 1). This requires aggregating
- // the indices of the individual parameters that have this attribute.
- llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
- if (gridConstantMetaData == nullptr) {
- // Create a new 'grid_constant' metadata node
- SmallVector<llvm::Metadata *> gridConstMetadata = {
- llvm::ValueAsMetadata::getConstant(
- llvm::ConstantInt::get(i32, argIdx + 1))};
- llvm::Metadata *llvmMetadata[] = {
- llvm::ValueAsMetadata::get(llvmFunc),
- llvm::MDString::get(llvmContext, "grid_constant"),
- llvm::MDNode::get(llvmContext, gridConstMetadata)};
- llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmContext, llvmMetadata);
- nvvmAnnotations->addOperand(llvmMetadataNode);
- } else {
- // Append argIdx + 1 to the 'grid_constant' argument list
- if (auto argList =
- dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
- llvm::TempMDTuple clonedArgList = argList->clone();
- clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
- llvm::ConstantInt::get(i32, argIdx + 1))));
- gridConstantMetaData->replaceOperandWith(
- 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
- }
- }
+ llvmFunc->addParamAttr(
+ argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant"));
}
return success();
}
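With this change the grid_constant marker is carried as a parameter attribute rather than through module-level !nvvm.annotations metadata. A minimal sketch of the new encoding on the LLVM side (function and argument index are placeholders):

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"

// Marks parameter `argIdx` of `fn` as a grid constant using the string
// attribute the translation above now emits, instead of appending an entry to
// the !nvvm.annotations named metadata.
static void markGridConstant(llvm::Function &fn, unsigned argIdx) {
  fn.addParamAttr(argIdx,
                  llvm::Attribute::get(fn.getContext(), "nvvm.grid_constant"));
}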
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2cdd502..8a1b554 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2893,6 +2893,12 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
alignment = builder.getInt64(intAttr.getInt());
assert(ty->isPointerTy() && "Invalid type for aligned variable");
assert(alignment && "Invalid alignment value");
+
+      // Skip emitting the alignment if the alignment value is not a power
+      // of 2.
+ if (!intAttr.getValue().isPowerOf2())
+ continue;
+
auto curInsert = builder.saveIP();
builder.SetInsertPoint(sourceBlock);
llvmVal = builder.CreateLoad(ty, llvmVal);
@@ -4356,9 +4362,11 @@ createAlteredByCaptureMap(MapInfoData &mapData,
if (!isPtrTy) {
auto curInsert = builder.saveIP();
+      llvm::DebugLoc dbgLoc = builder.getCurrentDebugLocation();
builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
auto *memTempAlloc =
builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
+      builder.SetCurrentDebugLocation(dbgLoc);
builder.restoreIP(curInsert);
builder.CreateStore(newV, memTempAlloc);
@@ -5865,6 +5873,10 @@ static bool isTargetDeviceOp(Operation *op) {
if (mlir::isa<omp::ThreadprivateOp>(op))
return true;
+ if (mlir::isa<omp::TargetAllocMemOp>(op) ||
+ mlir::isa<omp::TargetFreeMemOp>(op))
+ return true;
+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
if (auto declareTargetIface =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
@@ -5877,6 +5889,85 @@ static bool isTargetDeviceOp(Operation *op) {
return false;
}
+static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *i64Ty = builder.getInt64Ty();
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *returnType = builder.getPtrTy(0);
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
+ llvm::Function *func = cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
+ return func;
+}
+
+static LogicalResult
+convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
+ if (!allocMemOp)
+ return failure();
+
+ // Get "omp_target_alloc" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+ llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = allocMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the allocation size.
+ llvm::DataLayout dataLayout = llvmModule->getDataLayout();
+ mlir::Type heapTy = allocMemOp.getAllocatedType();
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
+ llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ for (auto typeParam : allocMemOp.getTypeparams())
+ allocSize =
+ builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+ // Create call to "omp_target_alloc" with the args as translated llvm values.
+ llvm::CallInst *call =
+ builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
+ llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
+
+ // Map the result
+ moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
+ return success();
+}
+
+static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *ptrTy = builder.getPtrTy(0);
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *voidTy = builder.getVoidTy();
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
+ llvm::Function *func = dyn_cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
+ return func;
+}
+
+static LogicalResult
+convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
+ if (!freeMemOp)
+ return failure();
+
+ // Get "omp_target_free" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::Function *ompTargetFreeFunc = getOmpTargetFree(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = freeMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the corresponding heapref value in llvm
+ mlir::Value heapref = freeMemOp.getHeapref();
+ llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
+ // Convert heapref int to ptr and call "omp_target_free"
+ llvm::Value *intToPtr =
+ builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
+  builder.CreateCall(ompTargetFreeFunc, {intToPtr, llvmDeviceNum});
+ return success();
+}
+
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
static LogicalResult
@@ -6051,6 +6142,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TargetAllocMemOp) {
+ return convertTargetAllocMemOp(*op, builder, moduleTranslation);
+ })
+ .Case([&](omp::TargetFreeMemOp) {
+ return convertTargetFreeMemOp(*op, builder, moduleTranslation);
+ })
.Default([&](Operation *inst) {
return inst->emitError()
<< "not yet implemented: " << inst->getName();
@@ -6287,9 +6384,8 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
if (ompBuilder->Config.isTargetDevice()) {
if (isTargetDeviceOp(op)) {
return convertTargetDeviceOp(op, builder, moduleTranslation);
- } else {
- return convertTargetOpsInNest(op, builder, moduleTranslation);
}
+ return convertTargetOpsInNest(op, builder, moduleTranslation);
}
return convertHostOrTargetOperation(op, builder, moduleTranslation);
}
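Both new conversions use the same recipe: declare the OpenMP runtime entry point with getOrInsertFunction and emit a plain call through the builder. A minimal sketch of that recipe for an omp_target_alloc-style call (names mirror the helpers above; values are placeholders):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

// Declares (or reuses) `omp_target_alloc` and emits a call allocating `size`
// bytes on device `device`, mirroring convertTargetAllocMemOp above.
static llvm::Value *emitTargetAlloc(llvm::IRBuilder<> &builder,
                                    llvm::Module &module, llvm::Value *size,
                                    llvm::Value *device) {
  llvm::FunctionType *fnType = llvm::FunctionType::get(
      builder.getPtrTy(0), {builder.getInt64Ty(), builder.getInt32Ty()},
      /*isVarArg=*/false);
  llvm::FunctionCallee callee =
      module.getOrInsertFunction("omp_target_alloc", fnType);
  return builder.CreateCall(callee, {size, device});
}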
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000..f94410d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_translation_library(MLIRPtrToLLVMIRTranslation
+ PtrToLLVMIRTranslation.cpp
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPtrDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+ )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..7b89ec8
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -0,0 +1,66 @@
+//===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR `ptr` dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the `ptr` dialect to LLVM IR.
+class PtrDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Translates the given operation to LLVM IR using the provided IR builder
+ /// and saving the state in `moduleTranslation`.
+ LogicalResult
+ convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ // Translation for ptr dialect operations to LLVM IR is currently
+ // unimplemented.
+ return op->emitError("Translation for ptr dialect operations to LLVM IR is "
+ "not implemented.");
+ }
+
+  /// Amends the operation with the given dialect attribute; currently
+  /// unimplemented for the `ptr` dialect.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ // Translation for ptr dialect operations to LLVM IR is currently
+ // unimplemented.
+ return op->emitError("Translation for ptr dialect operations to LLVM IR is "
+ "not implemented.");
+ }
+};
+} // namespace
+
+void mlir::registerPtrDialectTranslation(DialectRegistry &registry) {
+ registry.insert<ptr::PtrDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
+ dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerPtrDialectTranslation(MLIRContext &context) {
+ DialectRegistry registry;
+ registerPtrDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
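A client that wants the new ptr translation interface available, for example before invoking translateModuleToLLVMIR programmatically, registers it on the context with the entry points added above; a minimal sketch:

#include "mlir/IR/MLIRContext.h"
#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"

// Registers the ptr dialect and its (currently stub) LLVM IR translation
// interface on a context.
static void setUpContext(mlir::MLIRContext &context) {
  mlir::registerPtrDialectTranslation(context);
}

// Alternatively, via a DialectRegistry shared with other registrations:
//   mlir::DialectRegistry registry;
//   mlir::registerPtrDialectTranslation(registry);
//   context.appendDialectRegistry(registry);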
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 6325480..7a888bb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -16,7 +16,6 @@
#include "mlir/Target/LLVMIR/Import.h"
#include "AttrKindDetail.h"
-#include "DataLayoutImporter.h"
#include "DebugImporter.h"
#include "LoopAnnotationImporter.h"
@@ -25,6 +24,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Target/LLVMIR/DataLayoutImporter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/DepthFirstIterator.h"
@@ -1045,8 +1045,9 @@ LogicalResult ModuleImport::convertIFuncs() {
LogicalResult ModuleImport::convertDataLayout() {
Location loc = mlirModule.getLoc();
- DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout());
- if (!dataLayoutImporter.getDataLayout())
+ DataLayoutImporter dataLayoutImporter(
+ context, llvmModule->getDataLayout().getStringRepresentation());
+ if (!dataLayoutImporter.getDataLayoutSpec())
return emitError(loc, "cannot translate data layout: ")
<< dataLayoutImporter.getLastToken();
@@ -1054,7 +1055,7 @@ LogicalResult ModuleImport::convertDataLayout() {
emitWarning(loc, "unhandled data layout token: ") << token;
mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName,
- dataLayoutImporter.getDataLayout());
+ dataLayoutImporter.getDataLayoutSpec());
return success();
}
@@ -1408,6 +1409,67 @@ LogicalResult ModuleImport::convertIFunc(llvm::GlobalIFunc *ifunc) {
return success();
}
+/// Converts LLVM string, integer, and enum attributes into MLIR attributes,
+/// skipping those in `attributesToSkip` and emitting a warning at `loc` for
+/// any other unsupported attributes.
+static ArrayAttr
+convertLLVMAttributesToMLIR(Location loc, MLIRContext *context,
+ llvm::AttributeSet attributes,
+ ArrayRef<StringLiteral> attributesToSkip = {}) {
+ SmallVector<Attribute> mlirAttributes;
+ for (llvm::Attribute attr : attributes) {
+ StringRef attrName;
+ if (attr.isStringAttribute())
+ attrName = attr.getKindAsString();
+ else
+ attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
+ if (llvm::is_contained(attributesToSkip, attrName))
+ continue;
+
+ auto keyAttr = StringAttr::get(context, attrName);
+ if (attr.isStringAttribute()) {
+ StringRef val = attr.getValueAsString();
+ if (val.empty()) {
+ // For string attributes without values, add only the attribute name.
+ mlirAttributes.push_back(keyAttr);
+ continue;
+ }
+ // For string attributes with a value, create a [name, value] pair.
+ mlirAttributes.push_back(
+ ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
+ continue;
+ }
+ if (attr.isIntAttribute()) {
+ // For integer attributes, convert the value to a string and create a
+ // [name, value] pair.
+ auto val = std::to_string(attr.getValueAsInt());
+ mlirAttributes.push_back(
+ ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
+ continue;
+ }
+ if (attr.isEnumAttribute()) {
+ // For enum attributes, add only the attribute name.
+ mlirAttributes.push_back(keyAttr);
+ continue;
+ }
+
+ emitWarning(loc)
+ << "'" << attrName
+ << "' attribute is invalid on current operation, skipping it";
+ }
+ return ArrayAttr::get(context, mlirAttributes);
+}
+
+/// Converts LLVM attributes from `globalVar` into MLIR attributes and adds them
+/// to `globalOp` as target-specific attributes.
+static void processTargetSpecificAttrs(llvm::GlobalVariable *globalVar,
+ GlobalOp globalOp) {
+ ArrayAttr targetSpecificAttrs = convertLLVMAttributesToMLIR(
+ globalOp.getLoc(), globalOp.getContext(), globalVar->getAttributes());
+ if (!targetSpecificAttrs.empty())
+ globalOp.setTargetSpecificAttrsAttr(targetSpecificAttrs);
+}
+
LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
// Insert the global after the last one or at the start of the module.
OpBuilder::InsertionGuard guard = setGlobalInsertionPoint();
@@ -1473,6 +1535,8 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
if (globalVar->hasComdat())
globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat()));
+ processTargetSpecificAttrs(globalVar, globalOp);
+
return success();
}
@@ -2525,7 +2589,7 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
// List of LLVM IR attributes that map to an explicit attribute on the MLIR
// LLVMFuncOp.
-static constexpr std::array kExplicitAttributes{
+static constexpr std::array kExplicitLLVMFuncOpAttributes{
StringLiteral("aarch64_in_za"),
StringLiteral("aarch64_inout_za"),
StringLiteral("aarch64_new_za"),
@@ -2535,7 +2599,6 @@ static constexpr std::array kExplicitAttributes{
StringLiteral("aarch64_pstate_sm_compatible"),
StringLiteral("aarch64_pstate_sm_enabled"),
StringLiteral("alwaysinline"),
- StringLiteral("approx-func-fp-math"),
StringLiteral("convergent"),
StringLiteral("denormal-fp-math"),
StringLiteral("denormal-fp-math-f32"),
@@ -2543,6 +2606,7 @@ static constexpr std::array kExplicitAttributes{
StringLiteral("frame-pointer"),
StringLiteral("instrument-function-entry"),
StringLiteral("instrument-function-exit"),
+ StringLiteral("memory"),
StringLiteral("no-infs-fp-math"),
StringLiteral("no-nans-fp-math"),
StringLiteral("no-signed-zeros-fp-math"),
@@ -2557,61 +2621,17 @@ static constexpr std::array kExplicitAttributes{
StringLiteral("willreturn"),
};
+/// Converts LLVM attributes from `func` into MLIR attributes and adds them
+/// to `funcOp` as passthrough attributes, skipping those listed in
+/// `kExplicitLLVMFuncOpAttributes`.
static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
- MLIRContext *context = funcOp.getContext();
- SmallVector<Attribute> passthroughs;
llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes(
llvm::AttributeList::AttrIndex::FunctionIndex);
- for (llvm::Attribute attr : funcAttrs) {
- // Skip the memory attribute since the LLVMFuncOp has an explicit memory
- // attribute.
- if (attr.hasAttribute(llvm::Attribute::Memory))
- continue;
-
- // Skip invalid type attributes.
- if (attr.isTypeAttribute()) {
- emitWarning(funcOp.getLoc(),
- "type attributes on a function are invalid, skipping it");
- continue;
- }
-
- StringRef attrName;
- if (attr.isStringAttribute())
- attrName = attr.getKindAsString();
- else
- attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
- auto keyAttr = StringAttr::get(context, attrName);
-
- // Skip attributes that map to an explicit attribute on the LLVMFuncOp.
- if (llvm::is_contained(kExplicitAttributes, attrName))
- continue;
-
- if (attr.isStringAttribute()) {
- StringRef val = attr.getValueAsString();
- if (val.empty()) {
- passthroughs.push_back(keyAttr);
- continue;
- }
- passthroughs.push_back(
- ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
- continue;
- }
- if (attr.isIntAttribute()) {
- auto val = std::to_string(attr.getValueAsInt());
- passthroughs.push_back(
- ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
- continue;
- }
- if (attr.isEnumAttribute()) {
- passthroughs.push_back(keyAttr);
- continue;
- }
-
- llvm_unreachable("unexpected attribute kind");
- }
-
- if (!passthroughs.empty())
- funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs));
+ ArrayAttr passthroughAttr =
+ convertLLVMAttributesToMLIR(funcOp.getLoc(), funcOp.getContext(),
+ funcAttrs, kExplicitLLVMFuncOpAttributes);
+ if (!passthroughAttr.empty())
+ funcOp.setPassthroughAttr(passthroughAttr);
}
void ModuleImport::processFunctionAttributes(llvm::Function *func,
@@ -2703,10 +2723,6 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
attr.isStringAttribute())
funcOp.setNoNansFpMath(attr.getValueAsBool());
- if (llvm::Attribute attr = func->getFnAttribute("approx-func-fp-math");
- attr.isStringAttribute())
- funcOp.setApproxFuncFpMath(attr.getValueAsBool());
-
if (llvm::Attribute attr = func->getFnAttribute("instrument-function-entry");
attr.isStringAttribute())
funcOp.setInstrumentFunctionEntry(
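The refactored import keeps the attribute shape unchanged: value-less LLVM attributes become plain strings and valued ones become [name, value] string pairs inside the array attribute. A minimal sketch building that shape by hand (the attribute names here are examples, not tied to this patch):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

// Builds an array attribute equivalent to what convertLLVMAttributesToMLIR
// returns for one enum attribute ("nounwind") and one valued string attribute
// ("probe-stack"="inline-asm"). In MLIR syntax this prints as:
//   ["nounwind", ["probe-stack", "inline-asm"]]
static mlir::ArrayAttr buildExampleAttrs(mlir::MLIRContext *context) {
  auto key = mlir::StringAttr::get(context, "probe-stack");
  auto value = mlir::StringAttr::get(context, "inline-asm");
  return mlir::ArrayAttr::get(
      context, {mlir::StringAttr::get(context, "nounwind"),
                mlir::ArrayAttr::get(context, {key, value})});
}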
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b3a06e2..97253591 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1081,6 +1081,83 @@ static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
gv->setDSOLocal(true);
}
+/// Attempts to translate an MLIR attribute identified by `key`, optionally with
+/// the given `value`, into an LLVM IR attribute. Reports errors at `loc` if
+/// any. If the attribute name corresponds to a known LLVM IR attribute kind,
+/// creates the LLVM attribute of that kind; otherwise, keeps it as a string
+/// attribute. Performs additional checks for attributes known to have or not
+/// have a value in order to avoid assertions inside LLVM upon construction.
+static FailureOr<llvm::Attribute>
+convertMLIRAttributeToLLVM(Location loc, llvm::LLVMContext &ctx, StringRef key,
+ StringRef value = StringRef()) {
+ auto kind = llvm::Attribute::getAttrKindFromName(key);
+ if (kind == llvm::Attribute::None)
+ return llvm::Attribute::get(ctx, key, value);
+
+ if (llvm::Attribute::isIntAttrKind(kind)) {
+ if (value.empty())
+ return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
+
+ int64_t result;
+ if (!value.getAsInteger(/*Radix=*/0, result))
+ return llvm::Attribute::get(ctx, kind, result);
+ return llvm::Attribute::get(ctx, key, value);
+ }
+
+ if (!value.empty())
+ return emitError(loc) << "LLVM attribute '" << key
+ << "' does not expect a value, found '" << value
+ << "'";
+
+ return llvm::Attribute::get(ctx, kind);
+}
+
+/// Converts the MLIR attributes listed in the given array attribute into LLVM
+/// attributes. Returns an `AttrBuilder` containing the converted attributes.
+/// Reports errors at `loc`, if any, and returns immediately. Expects `arrayAttr`
+/// to contain either string attributes, treated as value-less LLVM attributes,
+/// or array attributes containing two string attributes, with the first string
+/// being the name of the corresponding LLVM attribute and the second string
+/// being its value. Note that even integer attributes are expected to have
+/// their values expressed as strings.
+static FailureOr<llvm::AttrBuilder>
+convertMLIRAttributesToLLVM(Location loc, llvm::LLVMContext &ctx,
+ ArrayAttr arrayAttr, StringRef arrayAttrName) {
+ llvm::AttrBuilder attrBuilder(ctx);
+ if (!arrayAttr)
+ return attrBuilder;
+
+ for (Attribute attr : arrayAttr) {
+ if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+ FailureOr<llvm::Attribute> llvmAttr =
+ convertMLIRAttributeToLLVM(loc, ctx, stringAttr.getValue());
+ if (failed(llvmAttr))
+ return failure();
+ attrBuilder.addAttribute(*llvmAttr);
+ continue;
+ }
+
+ auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+ if (!arrayAttr || arrayAttr.size() != 2)
+ return emitError(loc) << "expected '" << arrayAttrName
+ << "' to contain string or array attributes";
+
+ auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
+ auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
+ if (!keyAttr || !valueAttr)
+ return emitError(loc) << "expected arrays within '" << arrayAttrName
+ << "' to contain two strings";
+
+ FailureOr<llvm::Attribute> llvmAttr = convertMLIRAttributeToLLVM(
+ loc, ctx, keyAttr.getValue(), valueAttr.getValue());
+ if (failed(llvmAttr))
+ return failure();
+ attrBuilder.addAttribute(*llvmAttr);
+ }
+
+ return attrBuilder;
+}
+
LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
// Mapping from compile unit to its respective set of global variables.
DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars;
@@ -1191,6 +1268,15 @@ LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
}
}
}
+
+ // Forward the target-specific attributes to LLVM.
+ FailureOr<llvm::AttrBuilder> convertedTargetSpecificAttrs =
+ convertMLIRAttributesToLLVM(op.getLoc(), var->getContext(),
+ op.getTargetSpecificAttrsAttr(),
+ op.getTargetSpecificAttrsAttrName());
+ if (failed(convertedTargetSpecificAttrs))
+ return failure();
+ var->addAttributes(*convertedTargetSpecificAttrs);
}
// Create all llvm::GlobalAlias
@@ -1381,44 +1467,6 @@ LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
return success();
}
-/// Attempts to add an attribute identified by `key`, optionally with the given
-/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
-/// attribute has a kind known to LLVM IR, create the attribute of this kind,
-/// otherwise keep it as a string attribute. Performs additional checks for
-/// attributes known to have or not have a value in order to avoid assertions
-/// inside LLVM upon construction.
-static LogicalResult checkedAddLLVMFnAttribute(Location loc,
- llvm::Function *llvmFunc,
- StringRef key,
- StringRef value = StringRef()) {
- auto kind = llvm::Attribute::getAttrKindFromName(key);
- if (kind == llvm::Attribute::None) {
- llvmFunc->addFnAttr(key, value);
- return success();
- }
-
- if (llvm::Attribute::isIntAttrKind(kind)) {
- if (value.empty())
- return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
-
- int64_t result;
- if (!value.getAsInteger(/*Radix=*/0, result))
- llvmFunc->addFnAttr(
- llvm::Attribute::get(llvmFunc->getContext(), kind, result));
- else
- llvmFunc->addFnAttr(key, value);
- return success();
- }
-
- if (!value.empty())
- return emitError(loc) << "LLVM attribute '" << key
- << "' does not expect a value, found '" << value
- << "'";
-
- llvmFunc->addFnAttr(kind);
- return success();
-}
-
/// Return a representation of `value` as metadata.
static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
const llvm::APInt &value) {
@@ -1454,45 +1502,6 @@ static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
return llvm::MDNode::get(context, mdValues);
}
-/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
-/// Reports error to `loc` if any and returns immediately. Expects `attributes`
-/// to be an array attribute containing either string attributes, treated as
-/// value-less LLVM attributes, or array attributes containing two string
-/// attributes, with the first string being the name of the corresponding LLVM
-/// attribute and the second string beings its value. Note that even integer
-/// attributes are expected to have their values expressed as strings.
-static LogicalResult
-forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
- llvm::Function *llvmFunc) {
- if (!attributes)
- return success();
-
- for (Attribute attr : *attributes) {
- if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
- if (failed(
- checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
- return failure();
- continue;
- }
-
- auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2)
- return emitError(loc)
- << "expected 'passthrough' to contain string or array attributes";
-
- auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
- auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
- if (!keyAttr || !valueAttr)
- return emitError(loc)
- << "expected arrays within 'passthrough' to contain two strings";
-
- if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
- valueAttr.getValue())))
- return failure();
- }
- return success();
-}
-
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Clear the block, branch value mappings, they are only relevant within one
// function.
@@ -1561,10 +1570,6 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
if (auto noNansFpMath = func.getNoNansFpMath())
llvmFunc->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath));
- if (auto approxFuncFpMath = func.getApproxFuncFpMath())
- llvmFunc->addFnAttr("approx-func-fp-math",
- llvm::toStringRef(*approxFuncFpMath));
-
if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
llvmFunc->addFnAttr("no-signed-zeros-fp-math",
llvm::toStringRef(*noSignedZerosFpMath));
@@ -1864,9 +1869,13 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
}
// Forward the pass-through attributes to LLVM.
- if (failed(forwardPassthroughAttributes(
- function.getLoc(), function.getPassthrough(), llvmFunc)))
+ FailureOr<llvm::AttrBuilder> convertedPassthroughAttrs =
+ convertMLIRAttributesToLLVM(function.getLoc(), llvmFunc->getContext(),
+ function.getPassthroughAttr(),
+ function.getPassthroughAttrName());
+ if (failed(convertedPassthroughAttrs))
return failure();
+ llvmFunc->addFnAttrs(*convertedPassthroughAttrs);
// Convert visibility attribute.
llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
@@ -2407,11 +2416,6 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
if (failed(translator.convertUnresolvedBlockAddress()))
return nullptr;
- // Once we've finished constructing elements in the module, we should convert
- // it to use the debug info format desired by LLVM.
- // See https://llvm.org/docs/RemoveDIsDebugInfo.html
- translator.llvmModule->convertToNewDbgValues();
-
// Add the necessary debug info module flags, if they were not encoded in MLIR
// beforehand.
translator.debugTranslation->addModuleFlagsIfNotPresent();
diff --git a/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..044da1c
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_dialect_library(MLIRTargetLLVMIRTransforms
+ TargetToDataLayout.cpp
+ TargetToTargetFeatures.cpp
+ TargetUtils.cpp
+
+ DEPENDS
+ MLIRTargetLLVMIRTransformsIncGen
+
+ LINK_COMPONENTS
+ MC
+ Target
+ TargetParser
+ AllTargetsAsmParsers
+ AllTargetsCodeGens
+ AllTargetsDescs
+ AllTargetsInfos
+
+ LINK_LIBS PUBLIC
+ MLIRDLTIDialect
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRTargetLLVMIRImport
+ )
diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp
new file mode 100644
index 0000000..c0f9ceb
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp
@@ -0,0 +1,62 @@
+//===- TargetToDataLayout.cpp - extract data layout from TargetMachine ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
+#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h"
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Target/LLVMIR/Import.h"
+
+namespace mlir {
+namespace LLVM {
+#define GEN_PASS_DEF_LLVMTARGETTODATALAYOUT
+#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc"
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+struct TargetToDataLayoutPass
+ : public LLVM::impl::LLVMTargetToDataLayoutBase<TargetToDataLayoutPass> {
+ using LLVM::impl::LLVMTargetToDataLayoutBase<
+ TargetToDataLayoutPass>::LLVMTargetToDataLayoutBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ if (initializeLLVMTargets)
+ LLVM::detail::initializeBackendsOnce();
+
+ auto targetAttr = op->getAttrOfType<LLVM::TargetAttrInterface>(
+ LLVM::LLVMDialect::getTargetAttrName());
+ if (!targetAttr) {
+ op->emitError()
+ << "no TargetAttrInterface-implementing attribute at key \""
+ << LLVM::LLVMDialect::getTargetAttrName() << "\"";
+ return signalPassFailure();
+ }
+
+ FailureOr<llvm::DataLayout> dataLayout =
+ LLVM::detail::getDataLayout(targetAttr);
+ if (failed(dataLayout)) {
+ op->emitError() << "failed to obtain llvm::DataLayout for " << targetAttr;
+ return signalPassFailure();
+ }
+
+ DataLayoutSpecInterface dataLayoutSpec =
+ mlir::translateDataLayout(dataLayout.value(), &getContext());
+
+ if (auto existingDlSpec = op->getAttrOfType<DataLayoutSpecInterface>(
+ DLTIDialect::kDataLayoutAttrName)) {
+ dataLayoutSpec = existingDlSpec.combineWith({dataLayoutSpec});
+ }
+
+ op->setAttr(DLTIDialect::kDataLayoutAttrName, dataLayoutSpec);
+ }
+};
diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp
new file mode 100644
index 0000000..4a1ca46
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp
@@ -0,0 +1,78 @@
+//===- TargetToTargetFeatures.cpp - extract features from TargetMachine ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
+#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h"
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Target/LLVMIR/Import.h"
+
+#include "llvm/MC/MCSubtargetInfo.h"
+
+namespace mlir {
+namespace LLVM {
+#define GEN_PASS_DEF_LLVMTARGETTOTARGETFEATURES
+#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc"
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+struct TargetToTargetFeaturesPass
+ : public LLVM::impl::LLVMTargetToTargetFeaturesBase<
+ TargetToTargetFeaturesPass> {
+ using LLVM::impl::LLVMTargetToTargetFeaturesBase<
+ TargetToTargetFeaturesPass>::LLVMTargetToTargetFeaturesBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ if (initializeLLVMTargets)
+ LLVM::detail::initializeBackendsOnce();
+
+ auto targetAttr = op->getAttrOfType<LLVM::TargetAttr>(
+ LLVM::LLVMDialect::getTargetAttrName());
+ if (!targetAttr) {
+ op->emitError() << "no LLVM::TargetAttr attribute at key \""
+ << LLVM::LLVMDialect::getTargetAttrName() << "\"";
+ return signalPassFailure();
+ }
+
+ FailureOr<std::unique_ptr<llvm::TargetMachine>> targetMachine =
+ LLVM::detail::getTargetMachine(targetAttr);
+ if (failed(targetMachine)) {
+ op->emitError() << "failed to obtain llvm::TargetMachine for "
+ << targetAttr;
+ return signalPassFailure();
+ }
+
+ llvm::MCSubtargetInfo const *subTargetInfo =
+ (*targetMachine)->getMCSubtargetInfo();
+
+ const std::vector<llvm::SubtargetFeatureKV> enabledFeatures =
+ subTargetInfo->getEnabledProcessorFeatures();
+
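+    // Prefix each enabled subtarget feature with '+' to form LLVM feature
+    // strings (e.g. "+sse2" on x86-64 chips), one per enabled feature.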
+ auto plussedFeatures = llvm::to_vector(
+ llvm::map_range(enabledFeatures, [](llvm::SubtargetFeatureKV feature) {
+ return std::string("+") + feature.Key;
+ }));
+
+ auto plussedFeaturesRefs = llvm::to_vector(llvm::map_range(
+ plussedFeatures, [](auto &it) { return StringRef(it.c_str()); }));
+
+ auto fullTargetFeaturesAttr =
+ LLVM::TargetFeaturesAttr::get(&getContext(), plussedFeaturesRefs);
+
+ auto updatedTargetAttr =
+ LLVM::TargetAttr::get(&getContext(), targetAttr.getTriple(),
+ targetAttr.getChip(), fullTargetFeaturesAttr);
+
+ op->setAttr(LLVM::LLVMDialect::getTargetAttrName(), updatedTargetAttr);
+ }
+};
diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp
new file mode 100644
index 0000000..f1d3622
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp
@@ -0,0 +1,71 @@
+//===- TargetUtils.cpp - utils for obtaining generic target backend info --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Target/LLVMIR/Import.h"
+
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+
+#define DEBUG_TYPE "mlir-llvm-target-utils"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+void initializeBackendsOnce() {
+ static const auto initOnce = [] {
+    // Ensure that the targets LLVM has been configured to support are loaded
+    // into the TargetRegistry.
+ llvm::InitializeAllTargets();
+ llvm::InitializeAllTargetMCs();
+ return true;
+ }();
+ (void)initOnce; // Dummy usage.
+}
+
+FailureOr<std::unique_ptr<llvm::TargetMachine>>
+getTargetMachine(mlir::LLVM::TargetAttrInterface attr) {
+ StringRef triple = attr.getTriple();
+ StringRef chipAKAcpu = attr.getChip();
+ // NB: `TargetAttrInterface::getFeatures()` is coarsely typed to work around
+  // a cyclic dependency issue in TableGen files.
+ auto featuresAttr =
+ llvm::cast_if_present<LLVM::TargetFeaturesAttr>(attr.getFeatures());
+ std::string features = featuresAttr ? featuresAttr.getFeaturesString() : "";
+
+ std::string error;
+ const llvm::Target *target =
+ llvm::TargetRegistry::lookupTarget(triple, error);
+ if (!target || !error.empty()) {
+ LDBG() << "Looking up target '" << triple << "' failed: " << error << "\n";
+ return failure();
+ }
+
+ return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
+ llvm::Triple(triple), chipAKAcpu, features, {}, {}));
+}
+
+FailureOr<llvm::DataLayout>
+getDataLayout(mlir::LLVM::TargetAttrInterface attr) {
+ FailureOr<std::unique_ptr<llvm::TargetMachine>> targetMachine =
+ getTargetMachine(attr);
+ if (failed(targetMachine)) {
+ LDBG() << "Failed to retrieve the target machine for data layout.\n";
+ return failure();
+ }
+ return (targetMachine.value())->createDataLayout();
+}
+
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index e4ba478..ddd5946 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -71,7 +72,7 @@ public:
})
.Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
LLVM::LLVMPointerType, LLVM::LLVMStructType, VectorType,
- LLVM::LLVMTargetExtType>(
+ LLVM::LLVMTargetExtType, PtrLikeTypeInterface>(
[this](auto type) { return this->translate(type); })
.Default([](Type t) -> llvm::Type * {
llvm_unreachable("unknown LLVM dialect type");
@@ -149,6 +150,14 @@ private:
type.getIntParams());
}
+  /// Translates the given pointer-like type.
+ llvm::Type *translate(PtrLikeTypeInterface type) {
+ auto memSpace = dyn_cast<LLVM::AddressSpaceAttr>(type.getMemorySpace());
+ assert(memSpace && "expected pointer with the LLVM address space");
+ assert(!type.hasPtrMetadata() && "expected pointer without metadata");
+ return llvm::PointerType::get(context, memSpace.getAddressSpace());
+ }
+
/// Translates a list of types.
void translateTypes(ArrayRef<Type> types,
SmallVectorImpl<llvm::Type *> &result) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d8c54ec..3625dd2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -229,7 +229,7 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
}
template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
-LogicalResult deserializeCacheControlDecoration(
+static LogicalResult deserializeCacheControlDecoration(
Location loc, OpBuilder &opBuilder,
DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 7c007de..7fc7795 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
@@ -112,7 +113,9 @@ LogicalResult Serializer::serialize() {
// TODO: handle the other sections
processCapability();
- processExtension();
+ if (failed(processExtension())) {
+ return failure();
+ }
processMemoryModel();
processDebugInfo();
@@ -204,13 +207,24 @@ void Serializer::processDebugInfo() {
// TODO: Encode more debug instructions.
}
-void Serializer::processExtension() {
+LogicalResult Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName;
- for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
+ llvm::SmallSet<Extension, 4> deducedExts(
+ llvm::from_range, module.getVceTriple()->getExtensions());
+ auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
+ if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
+ TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
+ if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
+ return module.emitError(
+ "SPV_KHR_non_semantic_info extension not available");
+ deducedExts.insert(nonSemanticInfoExt);
+ }
+ for (spirv::Extension ext : deducedExts) {
extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
+ return success();
}
void Serializer::processMemoryModel() {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 7047869..fb2cecd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -102,7 +102,7 @@ private:
void processDebugInfo();
- void processExtension();
+ LogicalResult processExtension();
void processMemoryModel();
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index ac338d55..796354e 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -21,8 +21,11 @@
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
@@ -76,24 +79,66 @@ void registerFromSPIRVTranslation() {
// Serialization registration
//===----------------------------------------------------------------------===//
-static LogicalResult serializeModule(spirv::ModuleOp module,
- raw_ostream &output) {
+static LogicalResult
+serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output,
+ const spirv::SerializationOptions &options) {
SmallVector<uint32_t, 0> binary;
- if (failed(spirv::serialize(module, binary)))
+ if (failed(spirv::serialize(moduleOp, binary)))
return failure();
- output.write(reinterpret_cast<char *>(binary.data()),
- binary.size() * sizeof(uint32_t));
+ size_t sizeInBytes = binary.size() * sizeof(uint32_t);
+
+ output.write(reinterpret_cast<char *>(binary.data()), sizeInBytes);
+
+ if (options.saveModuleForValidation) {
+ size_t dirSeparator =
+ options.validationFilePrefix.find(llvm::sys::path::get_separator());
+    // If the file prefix includes a directory, check that the directory exists.
+ if (dirSeparator != std::string::npos) {
+ llvm::StringRef parentDir =
+ llvm::sys::path::parent_path(options.validationFilePrefix);
+ if (!llvm::sys::fs::is_directory(parentDir))
+        return moduleOp.emitError(
+            "validation prefix directory does not exist");
+ }
+
+ SmallString<128> filename;
+ int fd = 0;
+
+ std::error_code errorCode = llvm::sys::fs::createUniqueFile(
+ options.validationFilePrefix + "%%%%%%.spv", fd, filename);
+ if (errorCode)
+      return moduleOp.emitError("error creating validation output file: ")
+             << errorCode.message();
+
+ llvm::raw_fd_ostream validationOutput(fd, /*shouldClose=*/true);
+ validationOutput.write(reinterpret_cast<char *>(binary.data()),
+ sizeInBytes);
+ validationOutput.flush();
+ }
return mlir::success();
}
namespace mlir {
void registerToSPIRVTranslation() {
+ static llvm::cl::opt<std::string> validationFilesPrefix(
+ "spirv-save-validation-files-with-prefix",
+ llvm::cl::desc(
+          "When a non-empty string is passed, each serialized SPIR-V module "
+          "is saved to an additional file whose name starts with the given "
+          "prefix. This is used to generate separate binaries for validation, "
+          "where `--split-input-file` normally combines all outputs into one. "
+          "The combined output (`-o`) is still written. Created files need to "
+          "be removed manually once processed."),
+ llvm::cl::init(""));
+
TranslateFromMLIRRegistration toBinary(
"serialize-spirv", "serialize SPIR-V dialect",
- [](spirv::ModuleOp module, raw_ostream &output) {
- return serializeModule(module, output);
+ [](spirv::ModuleOp moduleOp, raw_ostream &output) {
+ return serializeModule(moduleOp, output,
+ {true, false, !validationFilesPrefix.empty(),
+ validationFilesPrefix});
},
[](DialectRegistry &registry) {
registry.insert<spirv::SPIRVDialect>();
diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt
new file mode 100644
index 0000000..890fc0ec
--- /dev/null
+++ b/mlir/lib/Target/Wasm/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIRTargetWasmImport
+ TranslateRegistration.cpp
+ TranslateFromWasm.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm
+
+ LINK_LIBS PUBLIC
+ MLIRWasmSSADialect
+ MLIRIR
+ MLIRSupport
+ MLIRTranslateLib
+)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
new file mode 100644
index 0000000..6afbe05
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -0,0 +1,1522 @@
+//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the WebAssembly importer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/Wasm/WasmBinaryEncoding.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LEB128.h"
+#include "llvm/Support/LogicalResult.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <variant>
+
+#define DEBUG_TYPE "wasm-translate"
+
+static_assert(CHAR_BIT == 8,
+              "This code expects bytes to be exactly 8 bits wide");
+
+using namespace mlir;
+using namespace mlir::wasm;
+using namespace mlir::wasmssa;
+
+namespace {
+using section_id_t = uint8_t;
+enum struct WasmSectionType : section_id_t {
+ CUSTOM = 0,
+ TYPE = 1,
+ IMPORT = 2,
+ FUNCTION = 3,
+ TABLE = 4,
+ MEMORY = 5,
+ GLOBAL = 6,
+ EXPORT = 7,
+ START = 8,
+ ELEMENT = 9,
+ CODE = 10,
+ DATA = 11,
+ DATACOUNT = 12
+};
+
+constexpr section_id_t highestWasmSectionID{
+ static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+
+#define APPLY_WASM_SEC_TRANSFORM \
+ WASM_SEC_TRANSFORM(CUSTOM) \
+ WASM_SEC_TRANSFORM(TYPE) \
+ WASM_SEC_TRANSFORM(IMPORT) \
+ WASM_SEC_TRANSFORM(FUNCTION) \
+ WASM_SEC_TRANSFORM(TABLE) \
+ WASM_SEC_TRANSFORM(MEMORY) \
+ WASM_SEC_TRANSFORM(GLOBAL) \
+ WASM_SEC_TRANSFORM(EXPORT) \
+ WASM_SEC_TRANSFORM(START) \
+ WASM_SEC_TRANSFORM(ELEMENT) \
+ WASM_SEC_TRANSFORM(CODE) \
+ WASM_SEC_TRANSFORM(DATA) \
+ WASM_SEC_TRANSFORM(DATACOUNT)
+
+template <WasmSectionType>
+constexpr const char *wasmSectionName = "";
+
+#define WASM_SEC_TRANSFORM(section) \
+ template <> \
+ [[maybe_unused]] constexpr const char \
+ *wasmSectionName<WasmSectionType::section> = #section;
+APPLY_WASM_SEC_TRANSFORM
+#undef WASM_SEC_TRANSFORM
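+
+// The X-macro above instantiates `wasmSectionName` for every section type, so
+// that e.g. wasmSectionName<WasmSectionType::TYPE> yields "TYPE".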
+
+constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
+ return secType != WasmSectionType::CUSTOM;
+}
+
+template <std::byte... Bytes>
+struct ByteSequence {};
+
+/// Template class representing a byte sequence consisting of a single byte.
+template <std::byte Byte>
+struct UniqueByte : ByteSequence<Byte> {};
+
+[[maybe_unused]] constexpr ByteSequence<
+ WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64,
+ WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64,
+ WasmBinaryEncoding::Type::v128> valueTypesEncodings{};
+
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+ ByteSequence<allowedFlags...> = {}) {
+ return ((value == allowedFlags) | ... | false);
+}
+
+template <std::byte... flags>
+constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
+ return !isValueOneOf<flags...>(value);
+}
+
+struct GlobalTypeRecord {
+ Type type;
+ bool isMutable;
+};
+
+struct TypeIdxRecord {
+ size_t id;
+};
+
+struct SymbolRefContainer {
+ FlatSymbolRefAttr symbol;
+};
+
+struct GlobalSymbolRefContainer : SymbolRefContainer {
+ Type globalType;
+};
+
+struct FunctionSymbolRefContainer : SymbolRefContainer {
+ FunctionType functionType;
+};
+
+using ImportDesc =
+ std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+
+using parsed_inst_t = FailureOr<SmallVector<Value>>;
+
+struct WasmModuleSymbolTables {
+ SmallVector<FunctionSymbolRefContainer> funcSymbols;
+ SmallVector<GlobalSymbolRefContainer> globalSymbols;
+ SmallVector<SymbolRefContainer> memSymbols;
+ SmallVector<SymbolRefContainer> tableSymbols;
+ SmallVector<FunctionType> moduleFuncTypes;
+
+ std::string getNewSymbolName(StringRef prefix, size_t id) const {
+ return (prefix + Twine{id}).str();
+ }
+
+ std::string getNewFuncSymbolName() const {
+ size_t id = funcSymbols.size();
+ return getNewSymbolName("func_", id);
+ }
+
+ std::string getNewGlobalSymbolName() const {
+ size_t id = globalSymbols.size();
+ return getNewSymbolName("global_", id);
+ }
+
+ std::string getNewMemorySymbolName() const {
+ size_t id = memSymbols.size();
+ return getNewSymbolName("mem_", id);
+ }
+
+ std::string getNewTableSymbolName() const {
+ size_t id = tableSymbols.size();
+ return getNewSymbolName("table_", id);
+ }
+};
+
+class ParserHead;
+
+/// Wrapper around SmallVector that only allows push and pop access to the
+/// stack, ensuring its state cannot be corrupted by arbitrary "free" accesses.
+class ValueStack {
+private:
+ struct LabelLevel {
+ size_t stackIdx;
+ LabelLevelOpInterface levelOp;
+ };
+
+public:
+ bool empty() const { return values.empty(); }
+
+ size_t size() const { return values.size(); }
+
+  /// Pops values from the stack so they can be used as operands of an
+  /// operation.
+  /// @param operandTypes The expected operand types of the operation, used to
+  /// know how many values to pop and to check that the popped types match.
+  /// @param opLoc Location of the caller, used to accurately report the
+  /// location if an error occurs.
+ /// @return Failure or the vector of popped values.
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
+ Location *opLoc);
+
+  /// Pushes the results of an operation onto the stack so they can be used by
+  /// a following operation.
+  /// @param results The list of results of the operation.
+  /// @param opLoc Location of the caller, used to accurately report the
+  /// location if an error occurs.
+ LogicalResult pushResults(ValueRange results, Location *opLoc);
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// A simple dump function for debugging.
+ /// Writes output to llvm::dbgs().
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+
+private:
+ SmallVector<Value> values;
+};
+
+using local_val_t = TypedValue<wasmssa::LocalRefType>;
+
+class ExpressionParser {
+public:
+ using locals_t = SmallVector<local_val_t>;
+ ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
+ ArrayRef<local_val_t> initLocal)
+ : parser{parser}, symbols{symbols}, locals{initLocal} {}
+
+private:
+ template <std::byte opCode>
+ inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
+
+ template <typename valueT>
+ parsed_inst_t
+ parseConstInst(OpBuilder &builder,
+ std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
+
+ /// Construct an operation with \p numOperands operands and a single result.
+ /// Each operand must have the same type. Suitable for e.g. binops, unary
+ /// ops, etc.
+ ///
+ /// \p opcode - The WASM opcode to build.
+ /// \p valueType - The operand and result type for the built instruction.
+ /// \p numOperands - The number of operands for the built operation.
+ ///
+ /// \returns The parsed instruction result, or failure.
+ template <typename opcode, typename valueType, unsigned int numOperands>
+ inline parsed_inst_t
+ buildNumericOp(OpBuilder &builder,
+ std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
+
+  /// This function generates a dispatch tree that associates an opcode with a
+  /// parser. Parsers are registered by specializing the
+  /// `parseSpecificInstruction` function for the opcode to handle.
+  ///
+  /// The dispatcher is generated by recursively enumerating all possible bit
+  /// patterns of an opcode and calling the relevant parser on each leaf.
+  ///
+  /// @tparam patternBitSize is the number of high bits of the opcode that are
+  /// already fixed, i.e. the index (from the MSB) of the first bit that is not
+  /// yet fixed
+  ///
+  /// @tparam highBitPattern is the fixed pattern of those `patternBitSize`
+  /// high bits that this instantiation handles
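+  ///
+  /// For example, dispatching the opcode 0b0100'0001 starts with
+  /// patternBitSize = 0 and highBitPattern = 0. The first step inspects bit 7
+  /// (here 0) and recurses into <1, 0b0>, the next inspects bit 6 (here 1) and
+  /// recurses into <2, 0b01>, and so on; once all 8 bits are fixed,
+  /// parseSpecificInstruction<std::byte{0b0100'0001}> handles the opcode.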
+ template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
+ inline parsed_inst_t dispatchToInstParser(std::byte opCode,
+ OpBuilder &builder) {
+ static_assert(patternBitSize <= 8,
+ "PatternBitSize is outside of range of opcode space! "
+ "(expected at most 8 bits)");
+ if constexpr (patternBitSize < 8) {
+ constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
+ constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
+ constexpr size_t nextPatternBitSize = patternBitSize + 1;
+ if ((opCode & bitSelect) != std::byte{0})
+ return dispatchToInstParser<nextPatternBitSize,
+ nextHighBitPatternStem | std::byte{1}>(
+ opCode, builder);
+ return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
+ opCode, builder);
+ } else {
+ return parseSpecificInstruction<highBitPattern>(builder);
+ }
+ }
+
+ struct ParseResultWithInfo {
+ SmallVector<Value> opResults;
+ std::byte endingByte;
+ };
+
+public:
+ template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+ parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
+
+ template <std::byte... ExpressionParseEnd>
+ FailureOr<ParseResultWithInfo>
+ parse(OpBuilder &builder,
+ ByteSequence<ExpressionParseEnd...> parsingEndFilters);
+
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
+ return valueStack.popOperands(operandTypes, &currentOpLoc.value());
+ }
+
+ LogicalResult pushResults(ValueRange results) {
+ return valueStack.pushResults(results, &currentOpLoc.value());
+ }
+
+  /// The local.set and local.tee operations behave similarly and only differ
+  /// in their return value. This function factors the shared behavior of the
+  /// two operations into one place.
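+  /// For example, `local.tee 0` stores the top of the stack into local 0 and
+  /// leaves the value on the stack, whereas `local.set 0` consumes it.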
+ template <typename OpToCreate>
+ parsed_inst_t parseSetOrTee(OpBuilder &);
+
+private:
+ std::optional<Location> currentOpLoc;
+ ParserHead &parser;
+ WasmModuleSymbolTables const &symbols;
+ locals_t locals;
+ ValueStack valueStack;
+};
+
+class ParserHead {
+public:
+ ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
+ ParserHead(ParserHead &&) = default;
+
+private:
+ ParserHead(ParserHead const &other) = default;
+
+public:
+ auto getLocation() const {
+ return FileLineColLoc::get(locName, 0, anchorOffset + offset);
+ }
+
+ FailureOr<StringRef> consumeNBytes(size_t nBytes) {
+ LDBG() << "Consume " << nBytes << " bytes";
+ LDBG() << " Bytes remaining: " << size();
+ LDBG() << " Current offset: " << offset;
+ if (nBytes > size())
+      return emitError(getLocation(), "trying to extract ")
+             << nBytes << " bytes when only " << size() << " are available";
+
+ StringRef res = head.slice(offset, offset + nBytes);
+ offset += nBytes;
+ LDBG() << " Updated offset (+" << nBytes << "): " << offset;
+ return res;
+ }
+
+ FailureOr<std::byte> consumeByte() {
+ FailureOr<StringRef> res = consumeNBytes(1);
+ if (failed(res))
+ return failure();
+ return std::byte{*res->bytes_begin()};
+ }
+
+ template <typename T>
+ FailureOr<T> parseLiteral();
+
+ FailureOr<uint32_t> parseVectorSize();
+
+private:
+  // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
+  // if the parseLiteral specializations were moved here, but the default GCC
+  // on Ubuntu 22.04 has a bug with template specializations in class
+  // declarations.
+ inline FailureOr<uint32_t> parseUI32();
+ inline FailureOr<int64_t> parseI64();
+
+public:
+ FailureOr<StringRef> parseName() {
+ FailureOr<uint32_t> size = parseVectorSize();
+ if (failed(size))
+ return failure();
+
+ return consumeNBytes(*size);
+ }
+
+ FailureOr<WasmSectionType> parseWasmSectionType() {
+ FailureOr<std::byte> id = consumeByte();
+ if (failed(id))
+ return failure();
+ if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
+ return emitError(getLocation(), "invalid section ID: ")
+ << static_cast<int>(*id);
+ return static_cast<WasmSectionType>(*id);
+ }
+
+ FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
+ using WasmLimits = WasmBinaryEncoding::LimitHeader;
+ FileLineColLoc limitLocation = getLocation();
+ FailureOr<std::byte> limitHeader = consumeByte();
+ if (failed(limitHeader))
+ return failure();
+
+ if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
+ return emitError(limitLocation, "invalid limit header: ")
+ << static_cast<int>(*limitHeader);
+ FailureOr<uint32_t> minParse = parseUI32();
+ if (failed(minParse))
+ return failure();
+ std::optional<uint32_t> max{std::nullopt};
+ if (*limitHeader == WasmLimits::bothLimits) {
+ FailureOr<uint32_t> maxParse = parseUI32();
+ if (failed(maxParse))
+ return failure();
+ max = *maxParse;
+ }
+ return LimitType::get(ctx, *minParse, max);
+ }
+
+ FailureOr<Type> parseValueType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> typeEncoding = consumeByte();
+ if (failed(typeEncoding))
+ return failure();
+ switch (*typeEncoding) {
+ case WasmBinaryEncoding::Type::i32:
+ return IntegerType::get(ctx, 32);
+ case WasmBinaryEncoding::Type::i64:
+ return IntegerType::get(ctx, 64);
+ case WasmBinaryEncoding::Type::f32:
+ return Float32Type::get(ctx);
+ case WasmBinaryEncoding::Type::f64:
+ return Float64Type::get(ctx);
+ case WasmBinaryEncoding::Type::v128:
+ return IntegerType::get(ctx, 128);
+ case WasmBinaryEncoding::Type::funcRef:
+ return wasmssa::FuncRefType::get(ctx);
+ case WasmBinaryEncoding::Type::externRef:
+ return wasmssa::ExternRefType::get(ctx);
+ default:
+ return emitError(typeLoc, "invalid value type encoding: ")
+ << static_cast<int>(*typeEncoding);
+ }
+ }
+
+ FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
+ using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
+ FailureOr<Type> typeParsed = parseValueType(ctx);
+ if (failed(typeParsed))
+ return failure();
+ FileLineColLoc mutLoc = getLocation();
+ FailureOr<std::byte> mutSpec = consumeByte();
+ if (failed(mutSpec))
+ return failure();
+ if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
+ return emitError(mutLoc, "invalid global mutability specifier: ")
+ << static_cast<int>(*mutSpec);
+ return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
+ }
+
+ FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
+ FailureOr<uint32_t> nParamsParsed = parseVectorSize();
+ if (failed(nParamsParsed))
+ return failure();
+ uint32_t nParams = *nParamsParsed;
+ SmallVector<Type> res{};
+ res.reserve(nParams);
+ for (size_t i = 0; i < nParams; ++i) {
+ FailureOr<Type> parsedType = parseValueType(ctx);
+ if (failed(parsedType))
+ return failure();
+ res.push_back(*parsedType);
+ }
+ return TupleType::get(ctx, res);
+ }
+
+ FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> funcTypeHeader = consumeByte();
+ if (failed(funcTypeHeader))
+ return failure();
+ if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
+ return emitError(typeLoc, "invalid function type header byte. Expecting ")
+ << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
+ << " got " << std::to_integer<unsigned>(*funcTypeHeader);
+ FailureOr<TupleType> inputTypes = parseResultType(ctx);
+ if (failed(inputTypes))
+ return failure();
+
+ FailureOr<TupleType> resTypes = parseResultType(ctx);
+ if (failed(resTypes))
+ return failure();
+
+ return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
+ }
+
+ FailureOr<TypeIdxRecord> parseTypeIndex() {
+ FailureOr<uint32_t> res = parseUI32();
+ if (failed(res))
+ return failure();
+ return TypeIdxRecord{*res};
+ }
+
+ FailureOr<TableType> parseTableType(MLIRContext *ctx) {
+ FailureOr<Type> elmTypeParse = parseValueType(ctx);
+ if (failed(elmTypeParse))
+ return failure();
+ if (!isWasmRefType(*elmTypeParse))
+ return emitError(getLocation(), "invalid element type for table");
+ FailureOr<LimitType> limitParse = parseLimit(ctx);
+ if (failed(limitParse))
+ return failure();
+ return TableType::get(ctx, *elmTypeParse, *limitParse);
+ }
+
+ FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
+ FileLineColLoc importLoc = getLocation();
+ FailureOr<std::byte> importType = consumeByte();
+ auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
+ if (failed(parseResult))
+ return failure();
+ return {*parseResult};
+ };
+ if (failed(importType))
+ return failure();
+ switch (*importType) {
+ case WasmBinaryEncoding::Import::typeID:
+ return packager(parseTypeIndex());
+ case WasmBinaryEncoding::Import::tableType:
+ return packager(parseTableType(ctx));
+ case WasmBinaryEncoding::Import::memType:
+ return packager(parseLimit(ctx));
+ case WasmBinaryEncoding::Import::globalType:
+ return packager(parseGlobalType(ctx));
+ default:
+ return emitError(importLoc, "invalid import type descriptor: ")
+ << static_cast<int>(*importType);
+ }
+ }
+
+ parsed_inst_t parseExpression(OpBuilder &builder,
+ WasmModuleSymbolTables const &symbols,
+ ArrayRef<local_val_t> locals = {}) {
+ auto eParser = ExpressionParser{*this, symbols, locals};
+ return eParser.parse(builder);
+ }
+
+ LogicalResult parseCodeFor(FuncOp func,
+ WasmModuleSymbolTables const &symbols) {
+ SmallVector<local_val_t> locals{};
+    // Populate locals with the function arguments.
+    Block &block = func.getBody().front();
+    // Delete the temporary return op that was only created for IR validity.
+ assert(func.getBody().getBlocks().size() == 1 &&
+ "Function should only have its default created block at this point");
+ assert(block.getOperations().size() == 1 &&
+ "Only the placeholder return op should be present at this point");
+ auto returnOp = cast<ReturnOp>(&block.back());
+ assert(returnOp);
+
+ FailureOr<uint32_t> codeSizeInBytes = parseUI32();
+ if (failed(codeSizeInBytes))
+ return failure();
+ FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes);
+ if (failed(codeContent))
+ return failure();
+ auto name = StringAttr::get(func->getContext(),
+ locName.str() + "::" + func.getSymName());
+ auto cParser = ParserHead{*codeContent, name};
+ FailureOr<uint32_t> localVecSize = cParser.parseVectorSize();
+ if (failed(localVecSize))
+ return failure();
+ OpBuilder builder{&func.getBody().front().back()};
+ for (auto arg : block.getArguments())
+ locals.push_back(cast<TypedValue<LocalRefType>>(arg));
+ // Declare the local ops
+ uint32_t nVarVec = *localVecSize;
+ for (size_t i = 0; i < nVarVec; ++i) {
+ FileLineColLoc varLoc = cParser.getLocation();
+ FailureOr<uint32_t> nSubVar = cParser.parseUI32();
+ if (failed(nSubVar))
+ return failure();
+ FailureOr<Type> varT = cParser.parseValueType(func->getContext());
+ if (failed(varT))
+ return failure();
+ for (size_t j = 0; j < *nSubVar; ++j) {
+ auto local = builder.create<LocalOp>(varLoc, *varT);
+ locals.push_back(local.getResult());
+ }
+ }
+ parsed_inst_t res = cParser.parseExpression(builder, symbols, locals);
+ if (failed(res))
+ return failure();
+ if (!cParser.end())
+ return emitError(cParser.getLocation(),
+ "unparsed garbage remaining at end of code block");
+ builder.create<ReturnOp>(func->getLoc(), *res);
+ returnOp->erase();
+ return success();
+ }
+
+ bool end() const { return curHead().empty(); }
+
+ ParserHead copy() const { return *this; }
+
+private:
+ StringRef curHead() const { return head.drop_front(offset); }
+
+ FailureOr<std::byte> peek() const {
+ if (end())
+ return emitError(
+ getLocation(),
+ "trying to peek at next byte, but input stream is empty");
+ return static_cast<std::byte>(curHead().front());
+ }
+
+ size_t size() const { return head.size() - offset; }
+
+ StringRef head;
+ StringAttr locName;
+ unsigned anchorOffset{0};
+ unsigned offset{0};
+};
+
+template <>
+FailureOr<float> ParserHead::parseLiteral<float>() {
+ FailureOr<StringRef> bytes = consumeNBytes(4);
+ if (failed(bytes))
+ return failure();
+ return llvm::support::endian::read<float>(bytes->bytes_begin(),
+ llvm::endianness::little);
+}
+
+template <>
+FailureOr<double> ParserHead::parseLiteral<double>() {
+ FailureOr<StringRef> bytes = consumeNBytes(8);
+ if (failed(bytes))
+ return failure();
+ return llvm::support::endian::read<double>(bytes->bytes_begin(),
+ llvm::endianness::little);
+}
+
+template <>
+FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
+ char const *error = nullptr;
+ uint32_t res{0};
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max()))
+ return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+ res = static_cast<uint32_t>(decoded);
+ offset += encodingSize;
+ return res;
+}
+
+template <>
+FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
+ char const *error = nullptr;
+ int32_t res{0};
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+ if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
+ std::isgreater(std::numeric_limits<int32_t>::min(), decoded))
+ return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+ res = static_cast<int32_t>(decoded);
+ offset += encodingSize;
+ return res;
+}
+
+template <>
+FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
+ char const *error = nullptr;
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ offset += encodingSize;
+ return res;
+}
+
+FailureOr<uint32_t> ParserHead::parseVectorSize() {
+ return parseLiteral<uint32_t>();
+}
+
+inline FailureOr<uint32_t> ParserHead::parseUI32() {
+ return parseLiteral<uint32_t>();
+}
+
+inline FailureOr<int64_t> ParserHead::parseI64() {
+ return parseLiteral<int64_t>();
+}
+
+template <std::byte opCode>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
+ return emitError(*currentOpLoc, "unknown instruction opcode: ")
+ << static_cast<int>(opCode);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void ValueStack::dump() const {
+ llvm::dbgs() << "================= Wasm ValueStack =======================\n";
+ llvm::dbgs() << "size: " << size() << "\n";
+ llvm::dbgs() << "<Top>"
+ << "\n";
+ // Stack is pushed to via push_back. Therefore the top of the stack is the
+ // end of the vector. Iterate in reverse so that the first thing we print
+ // is the top of the stack.
+ size_t stackSize = size();
+ for (size_t idx = 0; idx < stackSize; idx++) {
+ size_t actualIdx = stackSize - 1 - idx;
+ llvm::dbgs() << " ";
+ values[actualIdx].dump();
+ }
+ llvm::dbgs() << "<Bottom>"
+ << "\n";
+ llvm::dbgs() << "=========================================================\n";
+}
+#endif
+
+parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
+ LDBG() << "Popping from ValueStack\n"
+ << " Elements(s) to pop: " << operandTypes.size() << "\n"
+ << " Current stack size: " << values.size();
+ if (operandTypes.size() > values.size())
+ return emitError(*opLoc,
+ "stack doesn't contain enough values. trying to get ")
+ << operandTypes.size() << " operands on a stack containing only "
+ << values.size() << " values.";
+ size_t stackIdxOffset = values.size() - operandTypes.size();
+ SmallVector<Value> res{};
+ res.reserve(operandTypes.size());
+ for (size_t i{0}; i < operandTypes.size(); ++i) {
+ Value operand = values[i + stackIdxOffset];
+ Type stackType = operand.getType();
+ if (stackType != operandTypes[i])
+ return emitError(*opLoc, "invalid operand type on stack. expecting ")
+ << operandTypes[i] << ", value on stack is of type " << stackType
+ << ".";
+ LDBG() << " POP: " << operand;
+ res.push_back(operand);
+ }
+ values.resize(values.size() - operandTypes.size());
+ LDBG() << " Updated stack size: " << values.size();
+ return res;
+}
+
+LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
+ LDBG() << "Pushing to ValueStack\n"
+ << " Elements(s) to push: " << results.size() << "\n"
+ << " Current stack size: " << values.size();
+ for (Value val : results) {
+ if (!isWasmValueType(val.getType()))
+ return emitError(*opLoc, "invalid value type on stack: ")
+ << val.getType();
+ LDBG() << " PUSH: " << val;
+ values.push_back(val);
+ }
+
+ LDBG() << " Updated stack size: " << values.size();
+ return success();
+}
+
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+ UniqueByte<EndParseByte> endByte) {
+ auto res = parse(builder, ByteSequence<EndParseByte>{});
+ if (failed(res))
+ return failure();
+ return res->opResults;
+}
+
+template <std::byte... ExpressionParseEnd>
+FailureOr<ExpressionParser::ParseResultWithInfo>
+ExpressionParser::parse(OpBuilder &builder,
+ ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
+ SmallVector<Value> res;
+ for (;;) {
+ currentOpLoc = parser.getLocation();
+ FailureOr<std::byte> opCode = parser.consumeByte();
+ if (failed(opCode))
+ return failure();
+ if (isValueOneOf(*opCode, parsingEndFilters))
+ return {{res, *opCode}};
+ parsed_inst_t resParsed;
+ resParsed = dispatchToInstParser(*opCode, builder);
+ if (failed(resParsed))
+ return failure();
+ std::swap(res, *resParsed);
+ if (failed(pushResults(res)))
+ return failure();
+ }
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
+ FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
+ Location instLoc = *currentOpLoc;
+ if (failed(id))
+ return failure();
+ if (*id >= locals.size())
+ return emitError(instLoc, "invalid local index. function has ")
+ << locals.size() << " accessible locals, received index " << *id;
+ return {{builder.create<LocalGetOp>(instLoc, locals[*id]).getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::globalGet>(OpBuilder &builder) {
+ FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
+ Location instLoc = *currentOpLoc;
+ if (failed(id))
+ return failure();
+ if (*id >= symbols.globalSymbols.size())
+ return emitError(instLoc, "invalid global index. function has ")
+ << symbols.globalSymbols.size()
+ << " accessible globals, received index " << *id;
+ GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
+ auto globalOp = builder.create<GlobalGetOp>(instLoc, globalVar.globalType,
+ globalVar.symbol);
+
+ return {{globalOp.getResult()}};
+}
+
+template <typename OpToCreate>
+parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
+ FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
+ if (failed(id))
+ return failure();
+ if (*id >= locals.size())
+ return emitError(*currentOpLoc, "invalid local index. function has ")
+ << locals.size() << " accessible locals, received index " << *id;
+ if (valueStack.empty())
+ return emitError(
+ *currentOpLoc,
+ "invalid stack access, trying to access a value on an empty stack.");
+
+ parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
+ if (failed(poppedOp))
+ return failure();
+ return {
+ builder.create<OpToCreate>(*currentOpLoc, locals[*id], poppedOp->front())
+ ->getResults()};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::localSet>(OpBuilder &builder) {
+ return parseSetOrTee<LocalSetOp>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::localTee>(OpBuilder &builder) {
+ return parseSetOrTee<LocalTeeOp>(builder);
+}
+
+template <typename T>
+inline Type buildLiteralType(OpBuilder &);
+
+template <>
+inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+[[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+[[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+inline Type buildLiteralType<float>(OpBuilder &builder) {
+ return builder.getF32Type();
+}
+
+template <>
+inline Type buildLiteralType<double>(OpBuilder &builder) {
+ return builder.getF64Type();
+}
+
+template <typename ValT,
+ typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+struct AttrHolder;
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
+ using type = IntegerAttr;
+};
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
+ using type = FloatAttr;
+};
+
+template <typename ValT>
+using attr_holder_t = typename AttrHolder<ValT>::type;
+
+template <typename ValT,
+ typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
+ return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
+}
+
+template <typename valueT>
+parsed_inst_t ExpressionParser::parseConstInst(
+ OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
+ auto parsedConstant = parser.parseLiteral<valueT>();
+ if (failed(parsedConstant))
+ return failure();
+ auto constOp =
+ ConstOp::create(builder, *currentOpLoc,
+ buildLiteralAttr<valueT>(builder, *parsedConstant));
+ return {{constOp.getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
+ return parseConstInst<int32_t>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
+ return parseConstInst<int64_t>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
+ return parseConstInst<float>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
+ return parseConstInst<double>(builder);
+}
+
+template <typename opcode, typename valueType, unsigned int numOperands>
+inline parsed_inst_t ExpressionParser::buildNumericOp(
+ OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) {
+ auto ty = buildLiteralType<valueType>(builder);
+ LDBG() << "*** buildNumericOp: numOperands = " << numOperands
+ << ", type = " << ty << " ***";
+ auto tysToPop = SmallVector<Type, numOperands>();
+ tysToPop.resize(numOperands);
+ std::fill(tysToPop.begin(), tysToPop.end(), ty);
+ auto operands = popOperands(tysToPop);
+ if (failed(operands))
+ return failure();
+ auto op = builder.create<opcode>(*currentOpLoc, *operands).getResult();
+ LDBG() << "Built operation: " << op;
+ return {{op}};
+}
+
+// Convenience macro for generating numerical operations.
+#define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) { \
+ return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \
+ }
+
+// Macro to define binops that only support integer types.
+#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t)
+
+// Macro to define binops that only support floating point types.
+#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double)
+
+// Macro to define binops that support both floating point and integer types.
+#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
+
+// Macro to implement unary ops that only support integers.
+#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I32, int32_t) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I64, int64_t)
+
+// Macro to implement unary ops that support integer and floating point types.
+#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F32, float) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F64, double)
+
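+// For example, BUILD_NUMERIC_BINOP_INT(AndOp, and) expands into the two
+// specializations parseSpecificInstruction<WasmBinaryEncoding::OpCode::andI32>
+// and parseSpecificInstruction<WasmBinaryEncoding::OpCode::andI64>, which
+// forward to buildNumericOp<AndOp, int32_t, 2> and
+// buildNumericOp<AndOp, int64_t, 2>, respectively.
+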
+BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
+BUILD_NUMERIC_BINOP_FP(DivOp, div)
+BUILD_NUMERIC_BINOP_FP(MaxOp, max)
+BUILD_NUMERIC_BINOP_FP(MinOp, min)
+BUILD_NUMERIC_BINOP_INT(AndOp, and)
+BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
+BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
+BUILD_NUMERIC_BINOP_INT(OrOp, or)
+BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
+BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
+BUILD_NUMERIC_BINOP_INT(RotlOp, rotl)
+BUILD_NUMERIC_BINOP_INT(RotrOp, rotr)
+BUILD_NUMERIC_BINOP_INT(ShLOp, shl)
+BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
+BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
+BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
+BUILD_NUMERIC_BINOP_INTFP(AddOp, add)
+BUILD_NUMERIC_BINOP_INTFP(MulOp, mul)
+BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
+BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
+BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil)
+BUILD_NUMERIC_UNARY_OP_FP(FloorOp, floor)
+BUILD_NUMERIC_UNARY_OP_FP(NegOp, neg)
+BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
+BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
+BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
+BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
+BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
+
+// Don't need these anymore so let's undef them.
+#undef BUILD_NUMERIC_BINOP_FP
+#undef BUILD_NUMERIC_BINOP_INT
+#undef BUILD_NUMERIC_BINOP_INTFP
+#undef BUILD_NUMERIC_UNARY_OP_FP
+#undef BUILD_NUMERIC_UNARY_OP_INT
+#undef BUILD_NUMERIC_OP
+#undef BUILD_NUMERIC_CAST_OP
+
+class WasmBinaryParser {
+private:
+ struct SectionRegistry {
+ using section_location_t = StringRef;
+
+ std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
+ registry;
+
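+    /// Returns the content registered for \p SecType: an optional single
+    /// buffer for section types that must be unique, or the list of all
+    /// registered buffers for section types that may appear multiple times
+    /// (currently only CUSTOM).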
+ template <WasmSectionType SecType>
+ std::conditional_t<sectionShouldBeUnique(SecType),
+ std::optional<section_location_t>,
+ ArrayRef<section_location_t>>
+ getContentForSection() const {
+ constexpr auto idx = static_cast<size_t>(SecType);
+ if constexpr (sectionShouldBeUnique(SecType)) {
+ return registry[idx].empty() ? std::nullopt
+ : std::make_optional(registry[idx][0]);
+ } else {
+ return registry[idx];
+ }
+ }
+
+ bool hasSection(WasmSectionType secType) const {
+ return !registry[static_cast<size_t>(secType)].empty();
+ }
+
+ ///
+    /// @returns success if the registration is valid, failure if it cannot be
+    /// done (i.e. another section of the same type already exists and that
+    /// section type should only be present once).
+ ///
+ LogicalResult registerSection(WasmSectionType secType,
+ section_location_t location, Location loc) {
+ if (sectionShouldBeUnique(secType) && hasSection(secType))
+ return emitError(loc,
+ "trying to add a second instance of unique section");
+
+ registry[static_cast<size_t>(secType)].push_back(location);
+ emitRemark(loc, "Adding section with section ID ")
+ << static_cast<uint8_t>(secType);
+ return success();
+ }
+
+ LogicalResult populateFromBody(ParserHead ph) {
+ while (!ph.end()) {
+ FileLineColLoc sectionLoc = ph.getLocation();
+ FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
+ if (failed(secType))
+ return failure();
+
+ FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
+ if (failed(secSizeParsed))
+ return failure();
+
+ uint32_t secSize = *secSizeParsed;
+ FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
+ if (failed(sectionContent))
+ return failure();
+
+ LogicalResult registration =
+ registerSection(*secType, *sectionContent, sectionLoc);
+
+ if (failed(registration))
+ return failure();
+ }
+ return success();
+ }
+ };
+
+ auto getLocation(int offset = 0) const {
+ return FileLineColLoc::get(srcName, 0, offset);
+ }
+
+ template <WasmSectionType>
+ LogicalResult parseSectionItem(ParserHead &, size_t);
+
+ template <WasmSectionType section>
+ LogicalResult parseSection() {
+ auto secName = std::string{wasmSectionName<section>};
+ auto sectionNameAttr =
+ StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
+ unsigned offset = 0;
+ auto getLocation = [sectionNameAttr, &offset]() {
+ return FileLineColLoc::get(sectionNameAttr, 0, offset);
+ };
+ auto secContent = registry.getContentForSection<section>();
+ if (!secContent) {
+ LDBG() << secName << " section is not present in file.";
+ return success();
+ }
+
+ auto secSrc = secContent.value();
+ ParserHead ph{secSrc, sectionNameAttr};
+ FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
+ if (failed(nElemsParsed))
+ return failure();
+ uint32_t nElems = *nElemsParsed;
+ LDBG() << "starting to parse " << nElems << " items for section "
+ << secName;
+ for (size_t i = 0; i < nElems; ++i) {
+ if (failed(parseSectionItem<section>(ph, i)))
+ return failure();
+ }
+
+ if (!ph.end())
+ return emitError(getLocation(), "unparsed garbage at end of section ")
+ << secName;
+ return success();
+ }
+
+ /// Handles the registration of a function import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TypeIdxRecord tid) {
+ using llvm::Twine;
+ if (tid.id >= symbols.moduleFuncTypes.size())
+ return emitError(loc, "invalid type id: ")
+ << tid.id << ". Only " << symbols.moduleFuncTypes.size()
+             << " type registrations.";
+ FunctionType type = symbols.moduleFuncTypes[tid.id];
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
+ importName, type);
+ symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
+ return funcOp.verify();
+ }
+
+ /// Handles the registration of a memory import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, LimitType limitType) {
+ std::string symbol = symbols.getNewMemorySymbolName();
+ auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
+ importName, limitType);
+ symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
+ return memOp.verify();
+ }
+
+ /// Handles the registration of a table import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TableType tableType) {
+ std::string symbol = symbols.getNewTableSymbolName();
+ auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
+ importName, tableType);
+ symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
+ return tableOp.verify();
+ }
+
+ /// Handles the registration of a global variable import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, GlobalTypeRecord globalType) {
+ std::string symbol = symbols.getNewGlobalSymbolName();
+ auto giOp =
+ GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
+ globalType.type, globalType.isMutable);
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+ return giOp.verify();
+ }
+
+  // Detect occurrences of errors
+ LogicalResult peekDiag(Diagnostic &diag) {
+ if (diag.getSeverity() == DiagnosticSeverity::Error)
+ isValid = false;
+ return failure();
+ }
+
+public:
+ WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+ : builder{ctx}, ctx{ctx} {
+ ctx->getDiagEngine().registerHandler(
+ [this](Diagnostic &diag) { return peekDiag(diag); });
+ ctx->loadAllAvailableDialects();
+ if (sourceMgr.getNumBuffers() != 1) {
+ emitError(UnknownLoc::get(ctx), "one source file should be provided");
+ return;
+ }
+ uint32_t sourceBufId = sourceMgr.getMainFileID();
+ StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
+ srcName = StringAttr::get(
+ ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+
+ auto parser = ParserHead{source, srcName};
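+    // A Wasm module starts with the 4-byte magic "\0asm" followed by a
+    // 4-byte little-endian version field; only version 1 is accepted here.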
+ auto const wasmHeader = StringRef{"\0asm", 4};
+ FileLineColLoc magicLoc = parser.getLocation();
+ FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
+ if (failed(magic) || magic->compare(wasmHeader)) {
+      emitError(magicLoc, "source file does not contain a valid Wasm header");
+ return;
+ }
+ auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
+ FileLineColLoc versionLoc = parser.getLocation();
+ FailureOr<StringRef> version =
+ parser.consumeNBytes(expectedVersionString.size());
+ if (failed(version))
+ return;
+ if (version->compare(expectedVersionString)) {
+ emitError(versionLoc,
+                "unsupported Wasm version; only version 1 is supported");
+ return;
+ }
+ LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
+ if (failed(fillRegistry))
+ return;
+
+ mOp = ModuleOp::create(builder, getLocation());
+ builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
+ LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
+ if (failed(parsingTypes))
+ return;
+
+ LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
+ if (failed(parsingImports))
+ return;
+
+ firstInternalFuncID = symbols.funcSymbols.size();
+
+ LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
+ if (failed(parsingFunctions))
+ return;
+
+ LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
+ if (failed(parsingTables))
+ return;
+
+ LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
+ if (failed(parsingMems))
+ return;
+
+ LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>();
+ if (failed(parsingGlobals))
+ return;
+
+ LogicalResult parsingCode = parseSection<WasmSectionType::CODE>();
+ if (failed(parsingCode))
+ return;
+
+ LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
+ if (failed(parsingExports))
+ return;
+
+ // Copy over sizes of containers into statistics.
+ LDBG() << "WASM Imports:"
+ << "\n"
+ << " - Num functions: " << symbols.funcSymbols.size() << "\n"
+ << " - Num globals: " << symbols.globalSymbols.size() << "\n"
+ << " - Num memories: " << symbols.memSymbols.size() << "\n"
+ << " - Num tables: " << symbols.tableSymbols.size();
+ }
+
+ ModuleOp getModule() {
+ if (isValid)
+ return mOp;
+ if (mOp)
+ mOp.erase();
+ return ModuleOp{};
+ }
+
+private:
+ mlir::StringAttr srcName;
+ OpBuilder builder;
+ WasmModuleSymbolTables symbols;
+ MLIRContext *ctx;
+ ModuleOp mOp;
+ SectionRegistry registry;
+ size_t firstInternalFuncID{0};
+ bool isValid{true};
+};
+
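+// Section item parsers. Each specialization below consumes exactly one item
+// of the corresponding section from the parser head.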
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc importLoc = ph.getLocation();
+ auto moduleName = ph.parseName();
+ if (failed(moduleName))
+ return failure();
+
+ auto importName = ph.parseName();
+ if (failed(importName))
+ return failure();
+
+ FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
+ if (failed(import))
+ return failure();
+
+ return std::visit(
+ [this, importLoc, &moduleName, &importName](auto import) {
+ return visitImport(importLoc, *moduleName, *importName, import);
+ },
+ *import);
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc exportLoc = ph.getLocation();
+
+ auto exportName = ph.parseName();
+ if (failed(exportName))
+ return failure();
+
+ FailureOr<std::byte> opcode = ph.consumeByte();
+ if (failed(opcode))
+ return failure();
+
+ FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
+ if (failed(idx))
+ return failure();
+
+ using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
+ SmallVector<GlobalSymbolRefContainer>,
+ SmallVector<FunctionSymbolRefContainer>>;
+
+ SymbolRefDesc currentSymbolList;
+ std::string symbolType = "";
+ switch (*opcode) {
+ case WasmBinaryEncoding::Export::function:
+ symbolType = "function";
+ currentSymbolList = symbols.funcSymbols;
+ break;
+ case WasmBinaryEncoding::Export::table:
+ symbolType = "table";
+ currentSymbolList = symbols.tableSymbols;
+ break;
+ case WasmBinaryEncoding::Export::memory:
+ symbolType = "memory";
+ currentSymbolList = symbols.memSymbols;
+ break;
+ case WasmBinaryEncoding::Export::global:
+ symbolType = "global";
+ currentSymbolList = symbols.globalSymbols;
+ break;
+ default:
+ return emitError(exportLoc, "invalid value for export type: ")
+ << std::to_integer<unsigned>(*opcode);
+ }
+
+ auto currentSymbol = std::visit(
+ [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
+        if (*idx >= list.size()) {
+ emitError(
+ exportLoc,
+ llvm::formatv(
+ "trying to export {0} {1} which is undefined in this scope",
+ symbolType, *idx));
+ return failure();
+ }
+ return list[*idx].symbol;
+ },
+ currentSymbolList);
+
+ if (failed(currentSymbol))
+ return failure();
+
+ Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
+ SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+ StringAttr symName = SymbolTable::getSymbolName(op);
+ return SymbolTable{mOp}.rename(symName, *exportName);
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<TableType> tableType = ph.parseTableType(ctx);
+ if (failed(tableType))
+ return failure();
+ LDBG() << " Parsed table description: " << *tableType;
+ StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
+ auto tableOp =
+ TableOp::create(builder, opLocation, symbol.strref(), *tableType);
+ symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLoc = ph.getLocation();
+ auto typeIdxParsed = ph.parseLiteral<uint32_t>();
+ if (failed(typeIdxParsed))
+ return failure();
+ uint32_t typeIdx = *typeIdxParsed;
+ if (typeIdx >= symbols.moduleFuncTypes.size())
+    return emitError(opLoc, "invalid type index: ") << typeIdx;
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp =
+ FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
+ Block *block = funcOp.addEntryBlock();
+ OpBuilder::InsertionGuard guard{builder};
+ builder.setInsertionPointToEnd(block);
+ ReturnOp::create(builder, opLoc);
+ symbols.funcSymbols.push_back(
+ {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
+ symbols.moduleFuncTypes[typeIdx]});
+ return funcOp.verify();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
+ size_t) {
+ FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
+ if (failed(funcType))
+ return failure();
+ LDBG() << "Parsed function type " << *funcType;
+ symbols.moduleFuncTypes.push_back(*funcType);
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<LimitType> memory = ph.parseLimit(ctx);
+ if (failed(memory))
+ return failure();
+
+ LDBG() << " Registering memory " << *memory;
+ std::string symbol = symbols.getNewMemorySymbolName();
+ auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
+ symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc globalLocation = ph.getLocation();
+ auto globalTypeParsed = ph.parseGlobalType(ctx);
+ if (failed(globalTypeParsed))
+ return failure();
+
+ GlobalTypeRecord globalType = *globalTypeParsed;
+ auto symbol = builder.getStringAttr(symbols.getNewGlobalSymbolName());
+  auto globalOp = wasmssa::GlobalOp::create(
+      builder, globalLocation, symbol, globalType.type, globalType.isMutable);
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(globalOp)}, globalOp.getType()});
+ OpBuilder::InsertionGuard guard{builder};
+ Block *block = builder.createBlock(&globalOp.getInitializer());
+ builder.setInsertionPointToStart(block);
+ parsed_inst_t expr = ph.parseExpression(builder, symbols);
+ if (failed(expr))
+ return failure();
+ if (block->empty())
+ return emitError(globalLocation, "global with empty initializer");
+  if (expr->size() != 1 || (*expr)[0].getType() != globalType.type)
+ return emitError(
+ globalLocation,
+ "initializer result type does not match global declaration type");
+  ReturnOp::create(builder, globalLocation, *expr);
+ return success();
+}
+
+template <>
+LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>(
+ ParserHead &ph, size_t innerFunctionId) {
+  size_t funcId = innerFunctionId + firstInternalFuncID;
+ FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId];
+ auto funcOp =
+ dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(mOp, symRef.symbol));
+  assert(funcOp && "expected function symbol to resolve to a FuncOp");
+ if (failed(ph.parseCodeFor(funcOp, symbols)))
+ return failure();
+ return success();
+}
+} // namespace
+
+namespace mlir::wasm {
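+/// Parses the Wasm binary held by \p source and returns the imported module,
+/// or a null reference if an error was emitted during parsing.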
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+ MLIRContext *context) {
+ WasmBinaryParser wBN{source, context};
+ ModuleOp mOp = wBN.getModule();
+ if (mOp)
+ return {mOp};
+
+ return {nullptr};
+}
+} // namespace mlir::wasm
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
new file mode 100644
index 0000000..03b9784
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -0,0 +1,28 @@
+//===- TranslateRegistration.cpp - Register translation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+using namespace mlir;
+
+namespace mlir {
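+/// Registers the "import-wasm" translation with mlir-translate. A minimal
+/// usage sketch (assuming a standard mlir-translate build):
+///   mlir-translate --import-wasm module.wasm -o module.mlir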
+void registerFromWasmTranslation() {
+ TranslateToMLIRRegistration registration{
+ "import-wasm", "Translate WASM to MLIR",
+ [](llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) -> OwningOpRef<Operation *> {
+ return wasm::importWebAssemblyToModule(sourceMgr, context);
+ },
+ [](DialectRegistry &registry) {
+ registry.insert<wasmssa::WasmSSADialect>();
+ }};
+}
+} // namespace mlir
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 51e702a..c883baa 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -147,8 +147,9 @@ private:
std::string docStr;
{
llvm::raw_string_ostream docOS(docStr);
+ std::string tmpDocStr = doc.str();
raw_indented_ostream(docOS).printReindented(
- StringRef(docStr).rtrim(" \t"));
+ StringRef(tmpDocStr).rtrim(" \t"));
}
return docStr;
}
diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
index 9950050..6945c09 100644
--- a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
+++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
@@ -21,6 +21,7 @@
#include "llvm/LineEditor/LineEditor.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/Process.h"
#include "llvm/Support/SourceMgr.h"
//===----------------------------------------------------------------------===//
@@ -43,7 +44,7 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory));
static llvm::cl::opt<std::string> inputFilename(
- llvm::cl::Positional, llvm::cl::desc("<input file>"),
+ llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(mlirQueryCategory));
static llvm::cl::opt<bool> noImplicitModule{
@@ -68,6 +69,14 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
return mlir::success();
}
+ // When reading from stdin and the input is a tty, it is often a user mistake
+ // and the process "appears to be stuck". Print a message to let the user
+ // know!
+ if (inputFilename == "-" &&
+ llvm::sys::Process::FileDescriptorIsDisplayed(fileno(stdin)))
+ llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to "
+ "interrupt)\n";
+
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
index e89d392..34459b8 100644
--- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
+++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
@@ -26,9 +26,9 @@
using namespace mlir;
// Parse and verify the input MLIR file. Returns null on error.
-OwningOpRef<Operation *> loadModule(MLIRContext &context,
- StringRef inputFilename,
- bool insertImplictModule) {
+static OwningOpRef<Operation *> loadModule(MLIRContext &context,
+ StringRef inputFilename,
+ bool insertImplictModule) {
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
@@ -65,6 +65,11 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
"Disable implicit addition of a top-level module op during parsing"),
llvm::cl::init(false)};
+ static llvm::cl::opt<bool> allowUnregisteredDialects(
+ "allow-unregistered-dialect",
+ llvm::cl::desc("Allow operation with no registered dialects"),
+ llvm::cl::init(false));
+
llvm::cl::HideUnrelatedOptions(mlirReduceCategory);
llvm::InitLLVM y(argc, argv);
@@ -79,6 +84,8 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
llvm::cl::PrintHelpMessage();
return success();
}
+ if (allowUnregisteredDialects)
+ context.allowUnregisteredDialects();
std::string errorMessage;
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 09e5a02..8eaac30 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -177,11 +177,10 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
Operation *toOp) {
assert(fromOp->getBlock() == toOp->getBlock());
- assert(
- isa<MemoryEffectOpInterface>(fromOp) &&
- cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
- isa<MemoryEffectOpInterface>(toOp) &&
- cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+ assert(hasEffect<MemoryEffects::Read>(fromOp) &&
+ "expected read effect on fromOp");
+ assert(hasEffect<MemoryEffects::Read>(toOp) &&
+ "expected read effect on toOp");
Operation *nextOp = fromOp->getNextNode();
auto result =
memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
@@ -245,11 +244,10 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
// Some simple use case of operation with memory side-effect are dealt with
// here. Operations with no side-effect are done after.
if (!isMemoryEffectFree(op)) {
- auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
// TODO: Only basic use case for operations with MemoryEffects::Read can be
// eleminated now. More work needs to be done for more complicated patterns
// and other side-effects.
- if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+ if (!hasSingleEffect<MemoryEffects::Read>(op))
return failure();
// Look for an existing definition for the operation.
diff --git a/mlir/lib/Transforms/InlinerPass.cpp b/mlir/lib/Transforms/InlinerPass.cpp
index 703e517..77a9e6c 100644
--- a/mlir/lib/Transforms/InlinerPass.cpp
+++ b/mlir/lib/Transforms/InlinerPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Inliner.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_INLINER
@@ -120,8 +121,8 @@ static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
return true;
unsigned ratio = countOps(calleeRegion) * 100 / callerOps;
- LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
- << inliningThreshold << "%): " << ratio << "%\n");
+ LDBG() << "Callee / caller operation ratio (max: " << inliningThreshold
+ << "%): " << ratio << "%";
return ratio <= inliningThreshold;
}
@@ -138,7 +139,7 @@ void InlinerPass::runOnOperation() {
}
// By default, assume that any inlining is profitable.
- auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
+ auto profitabilityCb = [this](const Inliner::ResolvedCall &call) {
return isProfitableToInline(call, inliningThreshold);
};
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index cf039c3..d36a3c1 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
namespace mlir {
@@ -632,8 +633,7 @@ MemorySlotPromoter::promoteSlot() {
}
}
- LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
- << "\n");
+ LDBG() << "Promoted memory slot: " << slot.ptr;
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 4ccb83f..0e84b6d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG()
- << "Simple op is not memory effect free or has live results, skipping: "
- << *op;
+ LDBG() << "Simple op is not memory effect free or has live results, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return;
}
LDBG()
<< "Simple op has all dead results and is memory effect free, scheduling "
"for removal: "
- << *op;
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -345,8 +344,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// being returned, in order to optimize our IR. So, this demonstrates how we
// can make our optimization strong by even removing a live return value (%0),
// since it forwards only to non-live value(s) (%1#1).
- Operation *lastReturnOp = funcOp.back().getTerminator();
- size_t numReturns = lastReturnOp->getNumOperands();
+ size_t numReturns = funcOp.getNumResults();
BitVector nonLiveRets(numReturns, true);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -728,19 +726,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+ LDBG() << "Starting cleanup of dead values...";
+
// 1. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (auto &op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
op->dropAllUses();
op->erase();
}
// 2. Values
+ LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
+ LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}
// 3. Functions
+ LDBG() << "Cleaning up " << list.functions.size() << " functions";
for (auto &f : list.functions) {
+ LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
+ LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
+ LDBG() << " Erasing " << f.nonLiveRets.count()
+ << " non-live return values";
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
@@ -749,44 +759,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
// 4. Operands
+ LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0)
+ if (o.op->getNumOperands() > 0) {
+ LDBG() << "Erasing " << o.nonLive.count()
+ << " non-live operands from operation: "
+ << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
o.op->eraseOperands(o.nonLive);
+ }
}
// 5. Results
+ LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
+ LDBG() << "Erasing " << r.nonLive.count()
+ << " non-live results from operation: "
+ << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}
// 6. Blocks
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
+ LDBG() << "Erasing " << b.nonLiveArgs.count()
+ << " non-live arguments from block: " << b.b;
// it iterates backwards because erase invalidates all successor indexes
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
+ LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}
// 7. Successor Operands
+ LDBG() << "Cleaning up " << list.successorOperands.size()
+ << " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
+ LDBG() << "Erasing " << op.nonLiveOperands.count()
+ << " non-live successor operands from successor "
+ << op.successorIndex << " of branch: "
+ << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
// it iterates backwards because erase invalidates all successor indexes
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
+ LDBG() << " Erasing successor operand " << i << ": "
+ << successorOperands[i];
successorOperands.erase(i);
}
}
+
+ LDBG() << "Finished cleanup of dead values";
}
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67f536a..859c030 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -12,6 +12,7 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_SROA
@@ -180,8 +181,7 @@ static void destructureSlot(
assert(slot.ptr.use_empty() && "after destructuring, the original slot "
"pointer should no longer be used");
- LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
- << "\n");
+ LDBG() << "Destructured memory slot: " << slot.ptr;
if (statistics.destructuredAmount)
(*statistics.destructuredAmount)++;
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 0a925c4..87885be 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -13,8 +13,11 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
namespace mlir {
#define GEN_PASS_DEF_SYMBOLDCE
@@ -87,8 +90,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
SymbolTableCollection &symbolTable,
bool symbolTableIsHidden,
DenseSet<Operation *> &liveSymbols) {
- LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
- << "\n");
+ LDBG() << "computeLiveness: "
+ << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
// A worklist of live operations to propagate uses from.
SmallVector<Operation *, 16> worklist;
@@ -116,7 +119,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// consideration.
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
- LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
+ LDBG() << "processing: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If this is a symbol table, recursively compute its liveness.
if (op->hasTrait<OpTrait::SymbolTable>()) {
@@ -124,13 +128,14 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// symbol, or if it is a private symbol.
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
- LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
- << " is hidden: " << symIsHidden << "\n");
+ LDBG() << "\tsymbol table: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " is hidden: " << symIsHidden;
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
return failure();
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "\tnon-symbol table: " << op->getName() << "\n");
+ LDBG() << "\tnon-symbol table: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the op is not a symbol table, then, unless op itself is dead which
// would be handled by DCE, we need to check all the regions and blocks
// within the op to find the uses (e.g., consider visibility within op as
@@ -160,20 +165,17 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
}
SmallVector<Operation *, 4> resolvedSymbols;
- LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
+ LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (const SymbolTable::SymbolUse &use : *uses) {
- LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
+ LDBG() << "\tuse: " << use.getUser();
// Lookup the symbols referenced by this use.
resolvedSymbols.clear();
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
resolvedSymbols)))
// Ignore references to unknown symbols.
continue;
- LLVM_DEBUG({
- llvm::dbgs() << "\t\tresolved symbols: ";
- llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "\t\tresolved symbols: "
+ << llvm::interleaved(resolvedSymbols, ", ");
// Mark each of the resolved symbols as live.
for (Operation *resolvedSymbol : resolvedSymbols)
diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
index cfd568f..19cf464 100644
--- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -21,7 +21,10 @@
#include "mlir/Transforms/ControlFlowSinkUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/Support/DebugLog.h"
#include <vector>
#define DEBUG_TYPE "cf-sink"
@@ -84,13 +87,15 @@ bool Sinker::allUsersDominatedBy(Operation *op, Region *region) {
void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
std::vector<Operation *> &stack) {
- LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n"));
+ LDBG() << "Contained op: "
+ << OpWithFlags(user, OpPrintingFlags().skipRegions());
for (Value value : user->getOperands()) {
Operation *op = value.getDefiningOp();
// Ignore block arguments and ops that are already inside the region.
if (!op || op->getParentRegion() == region)
continue;
- LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n"));
+ LDBG() << "Try to sink:\n"
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the op's users are all in the region and it can be moved, then do so.
if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c26b4e..5ba109d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -182,15 +182,24 @@ private:
/// conversions.)
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values)) {
+ if (v.getDefiningOp() != op)
+ return nullptr;
+ }
+ return op;
+}
+
/// A vector of values is a pure type conversion if all values are defined by
/// the same operation and the operation has the `kPureTypeConversionMarker`
/// attribute.
static bool isPureTypeConversion(const ValueVector &values) {
assert(!values.empty() && "expected non-empty value vector");
- Operation *op = values.front().getDefiningOp();
- for (Value v : llvm::drop_begin(values))
- if (v.getDefiningOp() != op)
- return false;
+ Operation *op = getCommonDefiningOp(values);
return op && op->hasAttr(kPureTypeConversionMarker);
}
@@ -839,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -863,6 +873,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// failure.
template <typename RewriteTy, typename... Args>
void appendRewrite(Args &&...args) {
+ assert(config.allowPatternRollback && "appending rewrites is not allowed");
rewrites.push_back(
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
}
@@ -877,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasOpReplaced(Operation *op) const;
/// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
+ /// mapping, taking into account only replacements. Perform a best-effort
+ /// search for existing materializations with the desired types.
///
/// If `skipPureTypeConversions` is "true", materializations that are pure
/// type conversions are not considered.
@@ -915,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -926,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -976,7 +977,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr,
bool isPureTypeConversion = true);
/// Find a replacement value for the given SSA value in the conversion value
@@ -1058,14 +1058,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
+ /// This vector is maintained only if `allowPatternRollback` is set to
+ /// "true". Otherwise, all IR rewrites are materialized immediately and no
+ /// bookkeeping is needed.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization.
@@ -1089,6 +1092,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
+ /// A list of unresolved materializations that were created by the current
+ /// pattern.
+ DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
@@ -1104,15 +1111,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// A set of erased operations. This set is utilized only if
+ /// `allowPatternRollback` is set to "false". Conceptually, this set is
+ /// similar to `replacedOps` (which is maintained when the flag is set to
+ /// "true"). However, erasing from a DenseSet is more efficient than erasing
+ /// from a SetVector.
+ DenseSet<Operation *> erasedOps;
+
+ /// A set of erased blocks. This set is utilized only if
+ /// `allowPatternRollback` is set to "false".
+ DenseSet<Block *> erasedBlocks;
+
+ /// A rewriter that notifies the listener (if any) about all IR
+ /// modifications. This rewriter is utilized only if `allowPatternRollback`
+ /// is set to "false". If the flag is set to "true", the listener is notified
+ /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+ IRRewriter notifyingRewriter;
+
#ifndef NDEBUG
+ /// A set of replaced block arguments. This set is for debugging purposes
+ /// only and it is maintained only if `allowPatternRollback` is set to
+ /// "true".
+ DenseSet<BlockArgument> replacedArgs;
+
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
/// verification.
SmallPtrSet<Operation *, 1> pendingRootUpdates;
/// A raw output stream used to prefix the debug log.
- llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(),
- llvm::dbgs(), /*HasPendingNewline=*/false};
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit diagnostics during the conversion process.
llvm::ScopedPrinter logger{os};
@@ -1140,11 +1169,8 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
- if (!repl)
- return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
rewriter.replaceAllUsesWith(arg, repl);
return;
@@ -1161,6 +1187,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ if (!repl)
+ return;
+ performReplaceBlockArg(rewriter, arg, repl);
+}
+
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1223,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1246,6 +1280,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up a single value.
+ auto lookup = [&](const ValueVector &values) -> ValueVector {
+ assert(!values.empty() && "expected non-empty value vector");
+
+ // If the pattern rollback is enabled, use the mapping to look up the
+ // values.
+ if (config.allowPatternRollback)
+ return mapping.lookup(values);
+
+ // Otherwise, look up values by examining the IR. All replacements have
+ // already been materialized in IR.
+ Operation *op = getCommonDefiningOp(values);
+ if (!op)
+ return {};
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+ if (!castOp)
+ return {};
+ if (!this->unresolvedMaterializations.contains(castOp))
+ return {};
+ if (castOp.getOutputs() != values)
+ return {};
+ return castOp.getInputs();
+ };
+
// Helper function that looks up each value in `values` individually and then
// composes the results. If that fails, it tries to look up the entire vector
// at once.
@@ -1253,7 +1311,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// If possible, replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : values) {
- ValueVector r = mapping.lookup({v});
+ ValueVector r = lookup({v});
if (!r.empty()) {
llvm::append_range(next, r);
} else {
@@ -1273,7 +1331,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
- ValueVector r = mapping.lookup(values);
+ ValueVector r = lookup(values);
if (r.empty()) {
// No mapping found: The lookup stops here.
return {};
@@ -1347,21 +1405,13 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
- if (!config.allowPatternRollback &&
- !isa<UnresolvedMaterializationRewrite>(rewrite)) {
- // Unresolved materializations can always be rolled back (erased).
- llvm::report_fatal_error("pattern '" + patternName +
- "' rollback of IR modifications requested");
- }
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
rewrite->rollback();
- }
rewrites.resize(numRewritesToKeep);
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1383,7 +1433,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -1419,12 +1469,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ return wasOpReplaced(op) || ignoredOps.count(op);
}
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
- return replacedOps.count(op);
+ return replacedOps.count(op) || erasedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1432,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1448,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, &region->front(), &converter,
+ return applySignatureConversion(&region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, &region->front(), &converter,
- *conversion);
+ return applySignatureConversion(&region->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -1508,7 +1555,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// a bit more efficient, so we try to do that when possible.
bool fastPath = !config.listener;
if (fastPath) {
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ if (config.allowPatternRollback)
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
} else {
while (!block->empty())
@@ -1534,7 +1582,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
- /*castOp=*/nullptr, /*isPureTypeConversion=*/false)
+ /*isPureTypeConversion=*/false)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
@@ -1556,7 +1604,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+ if (config.allowPatternRollback)
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1575,7 +1624,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
+ bool isPureTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
@@ -1585,23 +1634,35 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// tracking the materialization like we do for other operations.
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- auto convertOp =
+ UnrealizedConversionCastOp convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+ if (config.attachDebugMaterializationKind) {
+ StringRef kindStr =
+ kind == MaterializationKind::Source ? "source" : "target";
+ convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+ }
if (isPureTypeConversion)
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
- if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
- if (castOp)
- *castOp = convertOp;
+
+ // Register the materialization.
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
- std::move(valuesToMap));
+ if (config.allowPatternRollback) {
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
+ } else {
+ patternMaterializations.insert(convertOp);
+ }
return convertOp.getResults();
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ assert(config.allowPatternRollback &&
+ "this code path is valid only in rollback mode");
+
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
@@ -1663,26 +1724,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
"attempting to insert into a block within a replaced/erased op");
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+
if (wasDetached) {
- // If the op was detached, it is most likely a newly created op.
- // TODO: If the same op is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same op multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateOperationRewrite>(op);
+    // If the op was detached, it is most likely a newly created op. Add it to
+    // the set of newly created ops, so that it will be legalized. If this op
+    // is not a newly created op, it will be legalized a second time, which is
+    // inefficient but harmless.
patternNewOps.insert(op);
+
+ if (config.allowPatternRollback) {
+ // TODO: If the same op is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same op multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateOperationRewrite>(op);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased operations that must be kept up to date.
+ erasedOps.erase(op);
+ }
return;
}
// The op was moved from one place to another.
- appendRewrite<MoveOperationRewrite>(op, previous);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveOperationRewrite>(op, previous);
+}
+
+/// Given that `fromRange` is about to be replaced with `toRange`, compute
+/// replacement values with the types of `fromRange`.
+static SmallVector<Value>
+getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
+ const SmallVector<SmallVector<Value>> &toRange,
+ const TypeConverter *converter) {
+ assert(!impl.config.allowPatternRollback &&
+ "this code path is valid only in 'no rollback' mode");
+ SmallVector<Value> repls;
+ for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+ if (from.use_empty()) {
+ // The replaced value is dead. No replacement value is needed.
+ repls.push_back(Value());
+ continue;
+ }
+
+ if (to.empty()) {
+ // The replaced value is dropped. Materialize a replacement value "out of
+ // thin air".
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+ /*outputTypes=*/from.getType(), /*originalType=*/Type(),
+ converter)[0];
+ repls.push_back(srcMat);
+ continue;
+ }
+
+ if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+ // The replacement value already has the correct type. Use it directly.
+ repls.push_back(to[0]);
+ continue;
+ }
+
+ // The replacement value has the wrong type. Build a source materialization
+ // to the original type.
+ // TODO: This is a bit inefficient. We should try to reuse existing
+ // materializations if possible. This would require an extension of the
+ // `lookupOrDefault` API.
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+ /*originalType=*/Type(), converter)[0];
+ repls.push_back(srcMat);
+ }
+
+ return repls;
}
void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
- assert(newValues.size() == op->getNumResults());
+ assert(newValues.size() == op->getNumResults() &&
+ "incorrect number of replacement values");
+
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ SmallVector<Value> repls = getReplacementValues(
+ *this, op->getResults(), newValues, currentTypeConverter);
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ op->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ op->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Replace the op with the replacement values and notify the listener.
+ notifyingRewriter.replaceOp(op, repls);
+ return;
+ }
+
assert(!ignoredOps.contains(op) && "operation was already replaced");
// Check if replaced op is an unresolved materialization, i.e., an
@@ -1704,8 +1858,7 @@ void ConversionPatternRewriterImpl::replaceOp(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
- currentTypeConverter, /*castOp=*/nullptr,
- /*isPureTypeConversion=*/false);
+ currentTypeConverter, /*isPureTypeConversion=*/false);
continue;
}
@@ -1722,11 +1875,59 @@ void ConversionPatternRewriterImpl::replaceOp(
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
+ if (!config.allowPatternRollback) {
+ SmallVector<Value> toConv = llvm::to_vector(to);
+ SmallVector<Value> repls =
+ getReplacementValues(*this, from, {toConv}, converter);
+ IRRewriter r(from.getContext());
+ Value repl = repls.front();
+ if (!repl)
+ return;
+
+ performReplaceBlockArg(r, from, repl);
+ return;
+ }
+
+#ifndef NDEBUG
+ // Make sure that a block argument is not replaced multiple times. In
+ // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
+ // uses of the given block argument, but also all future uses that may be
+ // introduced by future pattern applications. Therefore, it does not make
+ // sense to call `replaceUsesOfBlockArgument` multiple times with the same
+ // block argument. Doing so would overwrite the mapping and mess with the
+ // internal state of the dialect conversion driver.
+ assert(!replacedArgs.contains(from) &&
+ "attempting to replace a block argument that was already replaced");
+ replacedArgs.insert(from);
+#endif // NDEBUG
+
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ block->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ block->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Erase the block and notify the listener.
+ notifyingRewriter.eraseBlock(block);
+ return;
+ }
+
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
@@ -1760,23 +1961,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(newParentOp) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
"attempting to insert into a region within a replaced/erased op");
(void)newParentOp;
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyBlockInserted(block, previous, previousIt);
+
patternInsertedBlocks.insert(block);
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
- // TODO: If the same block is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same block multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateBlockRewrite>(block);
+ if (config.allowPatternRollback) {
+ // TODO: If the same block is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same block multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateBlockRewrite>(block);
+ } else {
+      // In "no rollback" mode, erased blocks are tracked in an extra data
+      // structure that must be kept up to date.
+ erasedBlocks.erase(block);
+ }
return;
}
// The block was moved from one place to another.
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1803,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -1880,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1889,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -1908,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -1921,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -1956,7 +2171,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// a bit more efficient, so we try to do that when possible.
bool fastPath = !getConfig().listener;
- if (fastPath)
+ if (fastPath && impl->config.allowPatternRollback)
impl->inlineBlockBefore(source, dest, before);
// Replace all uses of block arguments.
@@ -1982,6 +2197,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ // Pattern rollback is not allowed: no extra bookkeeping is needed.
+ PatternRewriter::startOpModification(op);
+ return;
+ }
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
#ifndef NDEBUG
@@ -1991,20 +2211,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
- assert(!impl->wasOpReplaced(op) &&
- "attempting to modify a replaced/erased op");
- PatternRewriter::finalizeOpModification(op);
impl->patternModifiedOps.insert(op);
+ if (!impl->config.allowPatternRollback) {
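+    // Pattern rollback is not allowed: finalize the modification right away
+    // and notify the conversion listener directly.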
+ PatternRewriter::finalizeOpModification(op);
+ if (getConfig().listener)
+ getConfig().listener->notifyOperationModified(op);
+ return;
+ }
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
}
void ConversionPatternRewriter::cancelOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
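+    // Pattern rollback is not allowed: no conversion-specific bookkeeping
+    // needs to be undone.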
+ PatternRewriter::cancelOpModification(op);
+ return;
+ }
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -2029,17 +2258,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
ArrayRef<ValueRange> operands) const {
SmallVector<Value> oneToOneOperands;
oneToOneOperands.reserve(operands.size());
for (ValueRange operand : operands) {
if (operand.size() != 1)
- llvm::report_fatal_error("pattern '" + getDebugName() +
- "' does not support 1:N conversion");
+ return failure();
+
oneToOneOperands.push_back(operand.front());
}
- return oneToOneOperands;
+ return std::move(oneToOneOperands);
}
LogicalResult
@@ -2054,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2076,7 +2305,8 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
@@ -2084,29 +2314,25 @@ public:
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2115,18 +2341,12 @@ private:
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2169,6 +2389,9 @@ private:
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget &target;
@@ -2177,9 +2400,10 @@ private:
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2193,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2257,19 +2479,21 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
- // If the operation isn't legal, try to fold it in-place.
- // TODO: Should we always try to do this, even if the op is
- // already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
- LLVM_DEBUG({
- logSuccess(logger, "operation was folded");
- logger.startLine() << logLineComment;
- });
- return success();
+  // If the operation is not legal, try to fold it in-place if the folding mode
+  // is 'BeforePatterns'. In 'Never' mode, folding is skipped entirely.
+ const ConversionConfig &config = rewriter.getConfig();
+ if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+ if (succeeded(legalizeWithFold(op))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2277,6 +2501,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation can't be legalized via patterns, try to fold it in-place
+ // if the folding mode is 'AfterPatterns'.
+ if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+ if (succeeded(legalizeWithFold(op))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
+ }
+
LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
@@ -2293,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2329,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
@@ -2381,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
llvm::join(insertedBlockNames, ", ") + "}");
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();
@@ -2415,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2425,17 +2657,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (!rewriterImpl.config.allowPatternRollback) {
- // Returning "failure" after modifying IR is not allowed.
+      // Erase all unresolved materializations created by the failed pattern
+      // application.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ rewriterImpl.patternMaterializations.clear();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Expensive pattern check that can detect API violations.
if (checkOp) {
OperationFingerPrint fingerPrintAfterPattern(checkOp);
if (fingerPrintAfterPattern != *topLevelFingerPrint)
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
"' returned failure but IR did change");
}
- }
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ }
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2459,12 +2697,22 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Eagerly erase unused materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ if (op->use_empty()) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ }
+ rewriterImpl.patternMaterializations.clear();
+ }
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2483,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2506,11 +2754,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
- auto &impl = rewriter.getImpl();
+ [[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -2528,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2540,15 +2787,17 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (Block *block : insertedBlocks) {
+ if (impl.erasedBlocks.contains(block))
+ continue;
+
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2564,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure();
}
- impl.applySignatureConversion(rewriter, block, converter, *conversion);
+ impl.applySignatureConversion(block, converter, *conversion);
continue;
}
@@ -2573,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// operation, and blocks in regions created by this pattern will already be
// legalized later on.
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp, rewriter))) {
+ if (failed(legalize(parentOp))) {
LLVM_DEBUG(logFailure(
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
@@ -2585,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
}
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(impl.logger,
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
"failed to legalize generated operation '{0}'({1})",
op->getName(), op));
return failure();
@@ -2599,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
}
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &modifiedOps) {
for (Operation *op : modifiedOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "failed to legalize operation updated in-place '{0}'",
- op->getName()));
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(
+ logFailure(rewriter.getImpl().logger,
+ "failed to legalize operation updated in-place '{0}'",
+ op->getName()));
return failure();
}
}
@@ -2825,21 +3073,22 @@ namespace mlir {
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
- explicit OperationConverter(const ConversionTarget &target,
+ explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : config(config), opLegalizer(target, patterns), mode(mode) {}
+ : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+ mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
- LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+ LogicalResult convert(Operation *op);
- /// Dialect conversion configuration.
- ConversionConfig config;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter rewriter;
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
@@ -2849,10 +3098,11 @@ private:
};
} // namespace mlir
-LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
- Operation *op) {
+LogicalResult OperationConverter::convert(Operation *op) {
+ const ConversionConfig &config = rewriter.getConfig();
+
// Legalize the given operation.
- if (failed(opLegalizer.legalize(op, rewriter))) {
+ if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
@@ -2928,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
}
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
- assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
@@ -2947,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
for (auto *op : toConvert) {
- if (failed(convert(rewriter, op))) {
+ if (failed(convert(op))) {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
@@ -2986,13 +3234,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
castOp->removeAttr(kPureTypeConversionMarker);
// Try to legalize all unresolved materializations.
- if (config.buildMaterializations) {
- IRRewriter rewriter(rewriterImpl.context, config.listener);
+ if (rewriter.getConfig().buildMaterializations) {
+    // Use a new rewriter so that these modifications are not tracked for
+    // rollback or other conversion bookkeeping.
+ IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
+ rewriter.getConfig().listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(
- legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
+ if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
+ it->second)))
return failure();
}
}
@@ -3159,6 +3410,27 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions)
+ return convertType(v.getType(), results);
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
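+  // Note: value-based (context-aware) conversions are not cached, because the
+  // result may depend on the value itself and not just on its type.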
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -3169,6 +3441,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -3178,21 +3460,38 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
+
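+// A value is legal if converting it yields exactly its current type.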
+bool TypeConverter::isLegal(Value value) const {
+ return convertType(value) == value.getType();
+}
+
bool TypeConverter::isLegal(Operation *op) const {
- return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
+ return isLegal(op->getOperands()) && isLegal(op->getResults());
}
bool TypeConverter::isLegal(Region *region) const {
- return llvm::all_of(*region, [this](Block &block) {
- return isLegal(block.getArgumentTypes());
- });
+ return llvm::all_of(
+ *region, [this](Block &block) { return isLegal(block.getArguments()); });
}
bool TypeConverter::isSignatureLegal(FunctionType ty) const {
- return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
+ if (!isLegal(ty.getInputs()))
+ return false;
+ if (!isLegal(ty.getResults()))
+ return false;
+ return true;
}
LogicalResult
@@ -3220,6 +3519,31 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return failure();
return success();
}
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
+ SignatureConversion &result) const {
+ // Try to convert the given input type.
+ SmallVector<Type, 1> convertedTypes;
+ if (failed(convertType(value, convertedTypes)))
+ return failure();
+
+ // If this argument is being dropped, there is nothing left to do.
+ if (convertedTypes.empty())
+ return success();
+
+ // Otherwise, add the new inputs.
+ result.addInputs(inputNo, convertedTypes);
+ return success();
+}
+LogicalResult
+TypeConverter::convertSignatureArgs(ValueRange values,
+ SignatureConversion &result,
+ unsigned origInputOffset) const {
+ for (unsigned i = 0, e = values.size(); i != e; ++i)
+ if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
+ return failure();
+ return success();
+}
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType,
@@ -3263,7 +3587,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion(
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
- if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
+ if (failed(convertSignatureArgs(block->getArguments(), conversion)))
return std::nullopt;
return conversion;
}
@@ -3388,7 +3712,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
@@ -3661,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
ctx->executeAction<ApplyConversionAction>(
[&] {
- OperationConverter opConverter(target, patterns, config, mode);
+ OperationConverter opConverter(ops.front()->getContext(), target,
+ patterns, config, mode);
status = opConverter.convertOperations(ops);
},
irUnits);
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 607b86c..0324588 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -15,6 +15,8 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
@@ -23,7 +25,7 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
@@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) {
return op;
}
static void logSuccessfulFolding(Operation *op) {
- llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
- op->dump();
- llvm::dbgs() << "\n\n";
+ LDBG() << "// *** IR Dump After Successful Folding ***\n"
+ << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
}
#endif // NDEBUG
@@ -394,8 +395,12 @@ private:
function_ref<void(Diagnostic &)> reasonCallback) override;
#ifndef NDEBUG
+  /// A raw output stream used to prefix the debug log.
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
+ llvm::ScopedPrinter logger{os};
#endif
/// The low-level pattern applicator.
@@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
- continueRewrites = processWorklist();
+ continueRewrites = false;
+
+          // Erase unreachable blocks first. Operations such as
+          //   %add = arith.addi %add, %add : i64
+          // are legal in unreachable code, but many patterns are unsafe to
+          // apply to such IR and can lead to crashes or infinite loops.
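+          // eraseUnreachableBlocks() succeeds only if it actually removed a
+          // block, so erasing dead code keeps the rewrite loop going for
+          // another iteration.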
+ continueRewrites |=
+ succeeded(eraseUnreachableBlocks(rewriter, region));
+
+ continueRewrites |= processWorklist();
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
@@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region &region,
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
region);
LogicalResult converged = std::move(driver).simplify(changed);
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after scanning "
- << config.getMaxIterations() << " times\n";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after scanning "
+ << config.getMaxIterations() << " times";
return converged;
}
@@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily(
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
*allErased = surviving.empty();
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.getMaxNumRewrites() << " rewrites";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after "
+ << config.getMaxNumRewrites() << " rewrites";
return converged;
}
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index eeb4052..73107cf 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -13,10 +13,12 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -182,13 +184,16 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
IRMapping &valueMapping) {
for (auto &block : *src) {
for (auto &op : block) {
+ // UnrealizedConversionCastOp is inlineable but cannot implement the
+ // inliner interface due to layering constraints.
+ if (isa<UnrealizedConversionCastOp>(op))
+ continue;
+
// Check this operation.
if (!interface.isLegalToInline(&op, insertRegion,
shouldCloneInlinedRegion, valueMapping)) {
- LLVM_DEBUG({
- llvm::dbgs() << "* Illegal to inline because of op: ";
- op.dump();
- });
+ LDBG() << "* Illegal to inline because of op: "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
return false;
}
// Check any nested regions.
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index cb3f2c5..111f58e 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -13,11 +13,13 @@
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <queue>
#define DEBUG_TYPE "licm"
@@ -64,8 +66,7 @@ size_t mlir::moveLoopInvariantCode(
size_t numMoved = 0;
for (Region *region : regions) {
- LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
- << *region->getParentOp() << "\n");
+ LDBG() << "Original loop:\n" << *region->getParentOp();
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
@@ -83,12 +84,13 @@ size_t mlir::moveLoopInvariantCode(
if (op->getParentRegion() != region)
continue;
- LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
+ LDBG() << "Checking op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
if (!shouldMoveOutOfRegion(op, region) ||
!canBeHoisted(op, definedOutside))
continue;
- LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
+ LDBG() << "Moving loop-invariant op: " << *op;
moveOutOfRegion(op, region);
++numMoved;
@@ -322,7 +324,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
- auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index a1d975d..31ae1d1 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -23,12 +23,15 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <deque>
#include <iterator>
using namespace mlir;
+#define DEBUG_TYPE "region-utils"
+
void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
@@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
+ LDBG() << "Starting eraseUnreachableBlocks with " << regions.size()
+ << " regions";
+
// Set of blocks found to be reachable within a given region.
llvm::df_iterator_default_set<Block *, 16> reachable;
// If any blocks were found to be dead.
- bool erasedDeadBlocks = false;
+ int erasedDeadBlocks = 0;
SmallVector<Region *, 1> worklist;
worklist.reserve(regions.size());
for (Region &region : regions)
worklist.push_back(&region);
+
+ LDBG(2) << "Initial worklist size: " << worklist.size();
+
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();
- if (region->empty())
+ if (region->empty()) {
+ LDBG(2) << "Skipping empty region";
continue;
+ }
+
+ LDBG(2) << "Processing region with " << region->getBlocks().size()
+ << " blocks";
+ if (region->getParentOp())
+ LDBG(2) << " -> for operation: "
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
// If this is a single block region, just collect the nested regions.
if (region->hasOneBlock()) {
@@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
for (Block *block : depth_first_ext(&region->front(), reachable))
(void)block /* Mark all reachable blocks */;
+ LDBG(2) << "Found " << reachable.size() << " reachable blocks out of "
+ << region->getBlocks().size() << " total blocks";
+
// Collect all of the dead blocks and push the live regions onto the
// worklist.
for (Block &block : llvm::make_early_inc_range(*region)) {
if (!reachable.count(&block)) {
+ LDBG() << "Erasing unreachable block: " << &block;
block.dropAllDefinedValueUses();
rewriter.eraseBlock(&block);
- erasedDeadBlocks = true;
+ ++erasedDeadBlocks;
continue;
}
@@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
}
}
- return success(erasedDeadBlocks);
+ LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks
+ << " dead blocks";
+
+ return success(erasedDeadBlocks > 0);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642..1382550 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,18 +13,40 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
namespace mlir {
+// Find all blocks reachable from the region's entry block and add them to
+// `reachableBlocks`.
+static void findReachableBlocks(Region &region,
+ DenseSet<Block *> &reachableBlocks) {
+ Block *entryBlock = &region.front();
+ reachableBlocks.insert(entryBlock);
+  // Traverse the CFG and add all reachable blocks to `reachableBlocks`.
+ SmallVector<Block *> worklist({entryBlock});
+ while (!worklist.empty()) {
+ Block *block = worklist.pop_back_val();
+ Operation *terminator = &block->back();
+ for (Block *successor : terminator->getSuccessors()) {
+ if (reachableBlocks.contains(successor))
+ continue;
+ worklist.push_back(successor);
+ reachableBlocks.insert(successor);
+ }
+ }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+  // Iterator over all reachable operations in a region.
+  // Also keeps track of whether the nested regions of the current op have
+  // already been visited, to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
+ regionIt = region->begin();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
+ }
+ // Advance the iterator to the next reachable operation.
+ void advance() {
+ assert(regionIt != region->end());
+ hasVisitedRegions = false;
+ if (blockIt == regionIt->end()) {
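+        // The current block is exhausted: advance to the next reachable block
+        // in this region (unreachable blocks are skipped entirely).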
+ ++regionIt;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << "Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+ // The region we're iterating over.
+ Region *region;
+    // Iterator pointing at the block currently being visited.
+ Region::iterator regionIt;
+    // Iterator pointing at the operation currently being visited.
+ Block::iterator blockIt;
+ // The set of blocks that are reachable in the current region.
+ DenseSet<Block *> reachableBlocks;
+ // Whether we've visited the nested regions of the current op already.
+ bool hasVisitedRegions = false;
+ };
+
+ // Worklist of regions to visit to drive the post-order traversal.
+ SmallVector<RegionReachableOpIterator> worklist;
+
+ LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
+ // Perform a post-order traversal of the regions, visiting each
+ // reachable operation.
for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
+ assert(worklist.empty());
+ if (region.empty())
+ continue;
+
+ // Prime the worklist with the entry block of this region.
+ worklist.push_back({&region});
+ while (!worklist.empty()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
+ for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion});
+ }
+ }
+            // If we're not at the back of the worklist, we've enqueued some
+            // nested region for processing and will come back to this op
+            // later (post-order).
+ if (&it != &worklist.back())
+ continue;
+
+ // Preemptively increment the iterator, in case the current op
+ // would be erased.
+ it.advance();
+
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- erasedListener.visitedOp = visitedOp;
+ erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ LDBG() << "\tOp matched and rewritten";
+ }
}
},
{op});