Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp | 8
-rw-r--r--  mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp | 130
-rw-r--r--  mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp | 83
-rw-r--r--  mlir/lib/Analysis/DataFlowFramework.cpp | 20
-rw-r--r--  mlir/lib/Analysis/Presburger/PresburgerRelation.cpp | 1
-rw-r--r--  mlir/lib/AsmParser/DialectSymbolParser.cpp | 7
-rw-r--r--  mlir/lib/AsmParser/Lexer.cpp | 27
-rw-r--r--  mlir/lib/AsmParser/Lexer.h | 3
-rw-r--r--  mlir/lib/AsmParser/TypeParser.cpp | 1
-rw-r--r--  mlir/lib/Bindings/Python/IRAttributes.cpp | 13
-rw-r--r--  mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 11
-rw-r--r--  mlir/lib/CMakeLists.txt | 34
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 12
-rw-r--r--  mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 101
-rw-r--r--  mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 16
-rw-r--r--  mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 38
-rw-r--r--  mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 1
-rw-r--r--  mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 1
-rw-r--r--  mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp | 5
-rw-r--r--  mlir/lib/Conversion/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 41
-rw-r--r--  mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp | 1
-rw-r--r--  mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp | 11
-rw-r--r--  mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp | 2
-rw-r--r--  mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp | 1
-rw-r--r--  mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 4
-rw-r--r--  mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 24
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 15
-rw-r--r--  mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2
-rw-r--r--  mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp | 13
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 3
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 78
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp | 36
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 15
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 53
-rw-r--r--  mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 9
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp | 2
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp | 2
-rw-r--r--  mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 3
-rw-r--r--  mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 2
-rw-r--r--  mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 12
-rw-r--r--  mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp | 1
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/CMakeLists.txt (renamed from mlir/lib/Conversion/MeshToMPI/CMakeLists.txt) | 8
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (renamed from mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp) | 146
-rw-r--r--  mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 134
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 158
-rw-r--r--  mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp | 1
-rw-r--r--  mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 75
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 71
-rw-r--r--  mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 5
-rw-r--r--  mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 64
-rw-r--r--  mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 1
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 40
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt | 4
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp | 29
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp | 97
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp | 55
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp | 10
-rw-r--r--  mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 13
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp | 33
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp | 55
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp | 53
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 138
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/Utils.cpp | 132
-rw-r--r--  mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp | 217
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 267
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp | 28
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp | 38
-rw-r--r--  mlir/lib/Dialect/Arith/Utils/Utils.cpp | 98
-rw-r--r--  mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp | 24
-rw-r--r--  mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 18
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp | 13
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 4
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 179
-rw-r--r--  mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp | 28
-rw-r--r--  mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp | 24
-rw-r--r--  mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp | 79
-rw-r--r--  mlir/lib/Dialect/Async/IR/Async.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp | 253
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp | 108
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 28
-rw-r--r--  mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp | 53
-rw-r--r--  mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp | 27
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp | 280
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 94
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp | 39
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp | 2
-rw-r--r--  mlir/lib/Dialect/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 14
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 4
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Func/Extensions/CMakeLists.txt | 8
-rw-r--r--  mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp (renamed from mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp) | 8
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 36
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp | 43
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/Utils.cpp | 29
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp | 10
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp | 7
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 3
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 28
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp | 5
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 425
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp | 59
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 64
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp | 171
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp | 79
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 86
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp | 17
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp | 26
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 181
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 96
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 45
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp | 17
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp | 58
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 63
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Padding.cpp | 23
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 50
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp (renamed from mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp) | 187
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp | 50
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 33
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 102
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 137
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp | 13
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp | 44
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 490
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 264
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 30
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 8
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 38
-rw-r--r--  mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp | 7
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp | 11
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp | 8
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp | 62
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp | 17
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp | 34
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 39
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp | 34
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | 59
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp | 91
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp | 19
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp | 17
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp | 4
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp | 6
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 159
-rw-r--r--  mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 66
-rw-r--r--  mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 3
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 17
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp | 84
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 143
-rw-r--r--  mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 44
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp | 17
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 158
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 18
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp | 6
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp | 59
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 20
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 73
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp | 20
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp | 14
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp | 177
-rw-r--r--  mlir/lib/Dialect/SMT/IR/SMTDialect.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SMT/IR/SMTOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 86
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 118
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 12
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 54
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Shape/IR/Shape.cpp | 58
-rw-r--r--  mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Shard/CMakeLists.txt (renamed from mlir/lib/Dialect/Mesh/CMakeLists.txt) | 0
-rw-r--r--  mlir/lib/Dialect/Shard/IR/CMakeLists.txt (renamed from mlir/lib/Dialect/Mesh/IR/CMakeLists.txt) | 8
-rw-r--r--  mlir/lib/Dialect/Shard/IR/ShardOps.cpp (renamed from mlir/lib/Dialect/Mesh/IR/MeshOps.cpp) | 529
-rw-r--r--  mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt (renamed from mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt) | 4
-rw-r--r--  mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp (renamed from mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp) | 270
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt (renamed from mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt) | 10
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Partition.cpp (renamed from mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp) | 413
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp (renamed from mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp) | 81
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp (renamed from mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp) | 64
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Transforms.cpp (renamed from mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp) | 82
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h (renamed from mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h) | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp | 14
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 16
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt | 8
-rw-r--r--  mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp (renamed from mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp) | 38
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tosa/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 26
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 126
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp | 104
-rw-r--r--  mlir/lib/Dialect/Vector/IR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 141
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 39
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 37
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 28
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp | 35
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 130
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 35
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 15
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 103
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 9
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 18
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 10
-rw-r--r--  mlir/lib/IR/AffineExpr.cpp | 2
-rw-r--r--  mlir/lib/IR/AsmPrinter.cpp | 7
-rw-r--r--  mlir/lib/IR/Diagnostics.cpp | 21
-rw-r--r--  mlir/lib/IR/Location.cpp | 4
-rw-r--r--  mlir/lib/IR/PDL/PDLPatternMatch.cpp | 3
-rw-r--r--  mlir/lib/IR/PatternLoggingListener.cpp | 28
-rw-r--r--  mlir/lib/IR/PatternMatch.cpp | 12
-rw-r--r--  mlir/lib/IR/SymbolTable.cpp | 1
-rw-r--r--  mlir/lib/IR/Value.cpp | 2
-rw-r--r--  mlir/lib/Parser/Parser.cpp | 30
-rw-r--r--  mlir/lib/Pass/Pass.cpp | 2
-rw-r--r--  mlir/lib/Pass/PassRegistry.cpp | 1
-rw-r--r--  mlir/lib/Query/Matcher/MatchersInternal.cpp | 1
-rw-r--r--  mlir/lib/RegisterAllDialects.cpp | 207
-rw-r--r--  mlir/lib/RegisterAllExtensions.cpp | 115
-rw-r--r--  mlir/lib/RegisterAllPasses.cpp | 99
-rw-r--r--  mlir/lib/Support/ToolUtilities.cpp | 39
-rw-r--r--  mlir/lib/Support/TypeID.cpp | 3
-rw-r--r--  mlir/lib/TableGen/Successor.cpp | 1
-rw-r--r--  mlir/lib/TableGen/Type.cpp | 1
-rw-r--r--  mlir/lib/Target/Cpp/TranslateRegistration.cpp | 2
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp | 8
-rw-r--r--  mlir/lib/Target/LLVM/CMakeLists.txt | 3
-rw-r--r--  mlir/lib/Target/LLVMIR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 50
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp | 3
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 59
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt | 21
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp | 103
-rw-r--r--  mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp | 10
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleImport.cpp | 105
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 61
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 43
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.h | 1
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 72
-rw-r--r--  mlir/lib/Tools/PDLL/AST/NodePrinter.cpp | 1
-rw-r--r--  mlir/lib/Tools/PDLL/ODS/Operation.cpp | 2
-rw-r--r--  mlir/lib/Tools/lsp-server-support/Protocol.cpp | 10
-rw-r--r--  mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp | 2
-rw-r--r--  mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp | 1
-rw-r--r--  mlir/lib/Tools/mlir-lsp-server/Protocol.cpp | 7
-rw-r--r--  mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 55
-rw-r--r--  mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp | 1
-rw-r--r--  mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp | 10
-rw-r--r--  mlir/lib/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Transforms/CSE.cpp | 1
-rw-r--r--  mlir/lib/Transforms/Canonicalizer.cpp | 1
-rw-r--r--  mlir/lib/Transforms/OpStats.cpp | 2
-rw-r--r--  mlir/lib/Transforms/RemoveDeadValues.cpp | 65
-rw-r--r--  mlir/lib/Transforms/Utils/DialectConversion.cpp | 227
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 1
-rw-r--r--  mlir/lib/Transforms/Utils/Inliner.cpp | 33
329 files changed, 7393 insertions, 6080 deletions
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 51fa773..fb5649e 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#define DEBUG_TYPE "constant-propagation"
@@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const {
LogicalResult SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
- LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+ LDBG() << "SCP: Visiting operation: " << *op;
// Don't try to simulate the results of a region operation as we can't
// guarantee that folding will be out-of-place. We don't allow in-place
@@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation(
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
- LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
+ LDBG() << "Folded to constant: " << attr;
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "Folded to value: " << cast<Value>(foldResult) << "\n");
+ LDBG() << "Folded to value: " << cast<Value>(foldResult);
AbstractSparseForwardDataFlowAnalysis::join(
lattice, *getLatticeElement(cast<Value>(foldResult)));
}
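The hunks above are representative of the logging migration applied throughout this commit: the LLVM_DEBUG(llvm::dbgs() << ... << "\n") idiom is replaced by the LDBG() stream from llvm/Support/DebugLog.h, which picks up DEBUG_TYPE as a prefix and terminates the line itself. A minimal sketch of the two styles side by side (the surrounding function is illustrative, not from the patch):

    #include "llvm/Support/Debug.h"
    #include "llvm/Support/DebugLog.h"

    #define DEBUG_TYPE "constant-propagation"

    static void logFold(int value) {
      // Old style: explicit guard, explicit stream, explicit newline.
      LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << value << "\n");
      // New style: LDBG() yields a stream that is compiled out in
      // non-assert builds and appends the newline automatically.
      LDBG() << "Folded to constant: " << value;
    }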
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 1abdfcb..10874fd 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -23,12 +23,11 @@
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
#define DEBUG_TYPE "dead-code-analysis"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
using namespace mlir::dataflow;
@@ -127,7 +126,8 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
}
LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
- LDBG("Initializing DeadCodeAnalysis for top-level op: " << top->getName());
+ LDBG() << "Initializing DeadCodeAnalysis for top-level op: "
+ << top->getName();
// Mark the top-level blocks as executable.
for (Region &region : top->getRegions()) {
if (region.empty())
@@ -135,7 +135,7 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
- LDBG("Marked entry block live for region in op: " << top->getName());
+ LDBG() << "Marked entry block live for region in op: " << top->getName();
}
// Mark as overdefined the predecessors of symbol callables with potentially
@@ -146,18 +146,18 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
}
void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
- LDBG("[init] Entering initializeSymbolCallables for top-level op: "
- << top->getName());
+ LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
+ << top->getName();
analysisScope = top;
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
- LDBG("[init] Processing symbol table op: " << symTable->getName());
+ LDBG() << "[init] Processing symbol table op: " << symTable->getName();
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();
bool foundSymbolCallable = false;
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
- LDBG("[init] Found CallableOpInterface: "
- << callable.getOperation()->getName());
+ LDBG() << "[init] Found CallableOpInterface: "
+ << callable.getOperation()->getName();
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
@@ -171,8 +171,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
auto *state =
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
- LDBG("[init] Marked callable as having unknown predecessors: "
- << callable.getOperation()->getName());
+ LDBG() << "[init] Marked callable as having unknown predecessors: "
+ << callable.getOperation()->getName();
}
foundSymbolCallable = true;
}
@@ -187,15 +187,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
if (!uses) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
- LDBG("[init] Could not gather symbol uses, conservatively marking "
- "all nested callables as having unknown predecessors");
+ LDBG() << "[init] Could not gather symbol uses, conservatively marking "
+ "all nested callables as having unknown predecessors";
return top->walk([&](CallableOpInterface callable) {
auto *state =
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
- LDBG("[init] Marked nested callable as "
- "having unknown predecessors: "
- << callable.getOperation()->getName());
+ LDBG() << "[init] Marked nested callable as "
+ "having unknown predecessors: "
+ << callable.getOperation()->getName();
});
}
@@ -209,15 +209,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
continue;
auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
propagateIfChanged(state, state->setHasUnknownPredecessors());
- LDBG("[init] Found non-call use for symbol, "
- "marked as having unknown predecessors: "
- << symbol->getName());
+ LDBG() << "[init] Found non-call use for symbol, "
+ "marked as having unknown predecessors: "
+ << symbol->getName();
}
};
SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
walkFn);
- LDBG("[init] Finished initializeSymbolCallables for top-level op: "
- << top->getName());
+ LDBG() << "[init] Finished initializeSymbolCallables for top-level op: "
+ << top->getName();
}
/// Returns true if the operation is a returning terminator in region
@@ -229,14 +229,14 @@ static bool isRegionOrCallableReturn(Operation *op) {
}
LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
- LDBG("[init] Entering initializeRecursively for op: " << op->getName()
- << " at " << op);
+ LDBG() << "[init] Entering initializeRecursively for op: " << op->getName()
+ << " at " << op;
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
- LDBG("[init] Visiting op with control-flow semantics: " << *op);
- // When the liveness of the parent block changes, make sure to re-invoke the
- // analysis on the op.
+ LDBG() << "[init] Visiting op with control-flow semantics: " << *op;
+ // When the liveness of the parent block changes, make sure to
+ // re-invoke the analysis on the op.
if (op->getBlock())
getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->blockContentSubscribe(this);
@@ -246,21 +246,21 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
// Recurse on nested operations.
for (Region &region : op->getRegions()) {
- LDBG("[init] Recursing into region of op: " << op->getName());
+ LDBG() << "[init] Recursing into region of op: " << op->getName();
for (Operation &nestedOp : region.getOps()) {
- LDBG("[init] Recursing into nested op: " << nestedOp.getName() << " at "
- << &nestedOp);
+ LDBG() << "[init] Recursing into nested op: " << nestedOp.getName()
+ << " at " << &nestedOp;
if (failed(initializeRecursively(&nestedOp)))
return failure();
}
}
- LDBG("[init] Finished initializeRecursively for op: " << op->getName()
- << " at " << op);
+ LDBG() << "[init] Finished initializeRecursively for op: " << op->getName()
+ << " at " << op;
return success();
}
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
- LDBG("Marking edge live from block " << from << " to block " << to);
+ LDBG() << "Marking edge live from block " << from << " to block " << to;
auto *state = getOrCreate<Executable>(getProgramPointBefore(to));
propagateIfChanged(state, state->setToLive());
auto *edgeState =
@@ -269,35 +269,35 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
}
void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
- LDBG("Marking entry blocks live for op: " << op->getName());
+ LDBG() << "Marking entry blocks live for op: " << op->getName();
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
- LDBG("Marked entry block live for region in op: " << op->getName());
+ LDBG() << "Marked entry block live for region in op: " << op->getName();
}
}
LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
- LDBG("Visiting program point: " << point << " " << *point);
+ LDBG() << "Visiting program point: " << point << " " << *point;
if (point->isBlockStart())
return success();
Operation *op = point->getPrevOp();
- LDBG("Visiting operation: " << *op);
+ LDBG() << "Visiting operation: " << *op;
// If the parent block is not executable, there is nothing to do.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->isLive()) {
- LDBG("Parent block not live, skipping op: " << *op);
+ LDBG() << "Parent block not live, skipping op: " << *op;
return success();
}
// We have a live call op. Add this as a live predecessor of the callee.
if (auto call = dyn_cast<CallOpInterface>(op)) {
- LDBG("Visiting call operation: " << *op);
+ LDBG() << "Visiting call operation: " << *op;
visitCallOperation(call);
}
@@ -305,12 +305,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumRegions()) {
// Check if we can reason about the region control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG("Visiting region branch operation: " << *op);
+ LDBG() << "Visiting region branch operation: " << *op;
visitRegionBranchOperation(branch);
// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
- LDBG("Visiting callable operation: " << *op);
+ LDBG() << "Visiting callable operation: " << *op;
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
@@ -322,19 +322,19 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
// Otherwise, conservatively mark all entry blocks as executable.
} else {
- LDBG("Marking all entry blocks live for op: " << *op);
+ LDBG() << "Marking all entry blocks live for op: " << *op;
markEntryBlocksLive(op);
}
}
if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
- LDBG("Visiting region terminator: " << *op);
+ LDBG() << "Visiting region terminator: " << *op;
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
- LDBG("Visiting callable terminator: " << *op);
+ LDBG() << "Visiting callable terminator: " << *op;
// Visit the exiting terminator of a callable.
visitCallableTerminator(op, callable);
}
@@ -343,12 +343,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumSuccessors()) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
- LDBG("Visiting branch operation: " << *op);
+ LDBG() << "Visiting branch operation: " << *op;
visitBranchOperation(branch);
// Otherwise, conservatively mark all successors as executable.
} else {
- LDBG("Marking all successors live for op: " << *op);
+ LDBG() << "Marking all successors live for op: " << *op;
for (Block *successor : op->getSuccessors())
markEdgeLive(op->getBlock(), successor);
}
@@ -358,7 +358,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
}
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
- LDBG("visitCallOperation: " << call.getOperation()->getName());
+ LDBG() << "visitCallOperation: " << call.getOperation()->getName();
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
// A call to an externally-defined callable has unknown predecessors.
@@ -381,15 +381,15 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
auto *callsites =
getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
propagateIfChanged(callsites, callsites->join(call));
- LDBG("Added callsite as predecessor for callable: "
- << callableOp->getName());
+ LDBG() << "Added callsite as predecessor for callable: "
+ << callableOp->getName();
} else {
// Mark this call op's predecessors as overdefined.
auto *predecessors =
getOrCreate<PredecessorState>(getProgramPointAfter(call));
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
- LDBG("Marked call op's predecessors as unknown for: "
- << call.getOperation()->getName());
+ LDBG() << "Marked call op's predecessors as unknown for: "
+ << call.getOperation()->getName();
}
}
@@ -421,7 +421,7 @@ DeadCodeAnalysis::getOperandValues(Operation *op) {
}
void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
- LDBG("visitBranchOperation: " << branch.getOperation()->getName());
+ LDBG() << "visitBranchOperation: " << branch.getOperation()->getName();
// Try to deduce a single successor for the branch.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
@@ -429,18 +429,18 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
if (Block *successor = branch.getSuccessorForOperands(*operands)) {
markEdgeLive(branch->getBlock(), successor);
- LDBG("Branch has single successor: " << successor);
+ LDBG() << "Branch has single successor: " << successor;
} else {
// Otherwise, mark all successors as executable and outgoing edges.
for (Block *successor : branch->getSuccessors())
markEdgeLive(branch->getBlock(), successor);
- LDBG("Branch has multiple/all successors live");
+ LDBG() << "Branch has multiple/all successors live";
}
}
void DeadCodeAnalysis::visitRegionBranchOperation(
RegionBranchOpInterface branch) {
- LDBG("visitRegionBranchOperation: " << branch.getOperation()->getName());
+ LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName();
// Try to deduce which regions are executable.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
@@ -457,19 +457,19 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
- LDBG("Marked region successor live: " << point);
+ LDBG() << "Marked region successor live: " << point;
// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(point);
propagateIfChanged(
predecessors,
predecessors->join(branch, successor.getSuccessorInputs()));
- LDBG("Added region branch as predecessor for successor: " << point);
+ LDBG() << "Added region branch as predecessor for successor: " << point;
}
}
void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
RegionBranchOpInterface branch) {
- LDBG("visitRegionTerminator: " << *op);
+ LDBG() << "visitRegionTerminator: " << *op;
std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
if (!operands)
return;
@@ -488,7 +488,7 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region->front()));
propagateIfChanged(state, state->setToLive());
- LDBG("Marked region entry block live for region: " << region);
+ LDBG() << "Marked region entry block live for region: " << region;
predecessors = getOrCreate<PredecessorState>(
getProgramPointBefore(&region->front()));
} else {
@@ -498,14 +498,14 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
}
propagateIfChanged(predecessors,
predecessors->join(op, successor.getSuccessorInputs()));
- LDBG("Added region terminator as predecessor for successor: "
- << (successor.getSuccessor() ? "region entry" : "parent op"));
+ LDBG() << "Added region terminator as predecessor for successor: "
+ << (successor.getSuccessor() ? "region entry" : "parent op");
}
}
void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
CallableOpInterface callable) {
- LDBG("visitCallableTerminator: " << *op);
+ LDBG() << "visitCallableTerminator: " << *op;
// Add this return op as a predecessor of all its callsites.
auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
@@ -516,15 +516,15 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
getOrCreate<PredecessorState>(getProgramPointAfter(predecessor));
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
- LDBG("Added callable terminator as predecessor for callsite: "
- << predecessor->getName());
+ LDBG() << "Added callable terminator as predecessor for callsite: "
+ << predecessor->getName();
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
propagateIfChanged(predecessors,
predecessors->setHasUnknownPredecessors());
- LDBG("Could not resolve callable terminator for callsite: "
- << predecessor->getName());
+ LDBG() << "Could not resolve callable terminator for callsite: "
+ << predecessor->getName();
}
}
}
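Note that DeadCodeAnalysis.cpp carried its own DBGS()/LDBG(X) macro definitions (deleted at the top of this file's diff); llvm/Support/DebugLog.h now provides the tagged stream centrally. A sketch of the call-site change implied by dropping the local macros (LDBG_OLD is a hypothetical stand-in for the deleted definition, renamed to avoid colliding with the header's LDBG):

    #include "llvm/Support/Debug.h"
    #include "llvm/Support/DebugLog.h"

    #define DEBUG_TYPE "dead-code-analysis"

    // The deleted per-file helpers rebuilt the "[dead-code-analysis] "
    // prefix by hand at every call site:
    #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
    #define LDBG_OLD(X) LLVM_DEBUG(DBGS() << X << "\n")

    static void example(int n) {
      LDBG_OLD("visiting " << n); // old: argument-style macro
      LDBG() << "visiting " << n; // new: stream-style, prefix added centrally
    }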
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 6a12fe3..509f520 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -10,7 +10,7 @@
#include <cassert>
#include <mlir/Analysis/DataFlow/LivenessAnalysis.h>
-#include <llvm/Support/Debug.h>
+#include <llvm/Support/DebugLog.h>
#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
#include <mlir/Analysis/DataFlow/Utils.h>
#include <mlir/Analysis/DataFlowFramework.h>
@@ -21,8 +21,6 @@
#include <mlir/Support/LLVM.h>
#define DEBUG_TYPE "liveness-analysis"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
using namespace mlir::dataflow;
@@ -81,16 +79,15 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
LogicalResult
LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
ArrayRef<const Liveness *> results) {
- LLVM_DEBUG(DBGS() << "[visitOperation] Enter: ";
- op->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n");
+ LDBG() << "[visitOperation] Enter: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// This marks values of type (1.a) and (4) liveness as "live".
if (!isMemoryEffectFree(op) || op->hasTrait<OpTrait::ReturnLike>()) {
- LDBG("[visitOperation] Operation has memory effects or is "
- "return-like, marking operands live");
+ LDBG() << "[visitOperation] Operation has memory effects or is "
+ "return-like, marking operands live";
for (auto *operand : operands) {
- LDBG(" [visitOperation] Marking operand live: "
- << operand << " (" << operand->isLive << ")");
+ LDBG() << " [visitOperation] Marking operand live: " << operand << " ("
+ << operand->isLive << ")";
propagateIfChanged(operand, operand->markLive());
}
}
@@ -99,28 +96,28 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
bool foundLiveResult = false;
for (const Liveness *r : results) {
if (r->isLive && !foundLiveResult) {
- LDBG("[visitOperation] Found live result, "
- "meeting all operands with result: "
- << r);
+ LDBG() << "[visitOperation] Found live result, "
+ "meeting all operands with result: "
+ << r;
// It is assumed that each operand is used to compute each result of an
// op. Thus, if at least one result is live, each operand is live.
for (Liveness *operand : operands) {
- LDBG(" [visitOperation] Meeting operand: " << operand
- << " with result: " << r);
+ LDBG() << " [visitOperation] Meeting operand: " << operand
+ << " with result: " << r;
meet(operand, *r);
}
foundLiveResult = true;
}
- LDBG("[visitOperation] Adding dependency for result: " << r << " after op: "
- << *op);
+ LDBG() << "[visitOperation] Adding dependency for result: " << r
+ << " after op: " << *op;
addDependency(const_cast<Liveness *>(r), getProgramPointAfter(op));
}
return success();
}
void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
- LDBG("Visiting branch operand: " << operand.get()
- << " in op: " << *operand.getOwner());
+ LDBG() << "Visiting branch operand: " << operand.get()
+ << " in op: " << *operand.getOwner();
// We know (at the moment) and assume (for the future) that `operand` is a
// non-forwarded branch operand of a `RegionBranchOpInterface`,
// `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op.
@@ -152,9 +149,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
for (Value result : op->getResults()) {
if (getLatticeElement(result)->isLive) {
mayLive = true;
- LDBG("[visitBranchOperand] Non-forwarded branch "
- "operand may be live due to live result: "
- << result);
+ LDBG() << "[visitBranchOperand] Non-forwarded branch "
+ "operand may be live due to live result: "
+ << result;
break;
}
}
@@ -174,8 +171,8 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
// Therefore, we conservatively consider the non-forwarded operand of the
// branch operation may live.
mayLive = true;
- LDBG("[visitBranchOperand] Non-forwarded branch operand may "
- "be live due to branch op interface");
+ LDBG() << "[visitBranchOperand] Non-forwarded branch operand may "
+ "be live due to branch op interface";
} else {
Operation *parentOp = op->getParentOp();
assert(isa<RegionBranchOpInterface>(parentOp) &&
@@ -191,9 +188,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
for (Value result : parentOp->getResults()) {
if (getLatticeElement(result)->isLive) {
mayLive = true;
- LDBG("[visitBranchOperand] Non-forwarded branch "
- "operand may be live due to parent live result: "
- << result);
+ LDBG() << "[visitBranchOperand] Non-forwarded branch "
+ "operand may be live due to parent live result: "
+ << result;
break;
}
}
@@ -214,9 +211,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
for (Operation &nestedOp : *block) {
if (!isMemoryEffectFree(&nestedOp)) {
mayLive = true;
- LDBG("Non-forwarded branch operand may be "
- "live due to memory effect in block: "
- << block);
+ LDBG() << "Non-forwarded branch operand may be "
+ "live due to memory effect in block: "
+ << block;
break;
}
}
@@ -224,7 +221,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
if (mayLive) {
Liveness *operandLiveness = getLatticeElement(operand.get());
- LDBG("Marking branch operand live: " << operand.get());
+ LDBG() << "Marking branch operand live: " << operand.get();
propagateIfChanged(operandLiveness, operandLiveness->markLive());
}
@@ -236,7 +233,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
SmallVector<const Liveness *, 4> resultsLiveness;
for (const Value result : op->getResults())
resultsLiveness.push_back(getLatticeElement(result));
- LDBG("Visiting operation for non-forwarded branch operand: " << *op);
+ LDBG() << "Visiting operation for non-forwarded branch operand: " << *op;
(void)visitOperation(op, operandLiveness, resultsLiveness);
// We also visit the parent op with the parent's results and this operand if
@@ -249,14 +246,14 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
SmallVector<const Liveness *, 4> parentResultsLiveness;
for (const Value parentResult : parentOp->getResults())
parentResultsLiveness.push_back(getLatticeElement(parentResult));
- LDBG("Visiting parent operation for non-forwarded branch operand: "
- << *parentOp);
+ LDBG() << "Visiting parent operation for non-forwarded branch operand: "
+ << *parentOp;
(void)visitOperation(parentOp, operandLiveness, parentResultsLiveness);
}
void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
- LDBG("Visiting call operand: " << operand.get()
- << " in op: " << *operand.getOwner());
+ LDBG() << "Visiting call operand: " << operand.get()
+ << " in op: " << *operand.getOwner();
// We know (at the moment) and assume (for the future) that `operand` is a
// non-forwarded call operand of an op implementing `CallOpInterface`.
assert(isa<CallOpInterface>(operand.getOwner()) &&
@@ -269,18 +266,18 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
// This marks values of type (1.c) liveness as "live". A non-forwarded
// call operand is live.
Liveness *operandLiveness = getLatticeElement(operand.get());
- LDBG("Marking call operand live: " << operand.get());
+ LDBG() << "Marking call operand live: " << operand.get();
propagateIfChanged(operandLiveness, operandLiveness->markLive());
}
void LivenessAnalysis::setToExitState(Liveness *lattice) {
- LDBG("setToExitState for lattice: " << lattice);
+ LDBG() << "setToExitState for lattice: " << lattice;
if (lattice->isLive) {
- LDBG("Lattice already live, nothing to do");
+ LDBG() << "Lattice already live, nothing to do";
return;
}
// This marks values of type (2) liveness as "live".
- LDBG("Marking lattice live due to exit state");
+ LDBG() << "Marking lattice live due to exit state";
(void)lattice->markLive();
propagateIfChanged(lattice, ChangeResult::Change);
}
@@ -290,14 +287,14 @@ void LivenessAnalysis::setToExitState(Liveness *lattice) {
//===----------------------------------------------------------------------===//
RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
- LDBG("Constructing RunLivenessAnalysis for op: " << op->getName());
+ LDBG() << "Constructing RunLivenessAnalysis for op: " << op->getName();
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
solver.load<LivenessAnalysis>(symbolTable);
- LDBG("Initializing and running solver");
+ LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG("Dumping liveness state for op");
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
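The LivenessAnalysis hunks also show the OpWithFlags helper used with the new stream: instead of calling Operation::print on the raw stream inside an LLVM_DEBUG block, the op and its printing flags are bundled into a single streamable value. A sketch of the pattern (traceOp is hypothetical):

    #include "mlir/IR/Operation.h"
    #include "llvm/Support/DebugLog.h"

    #define DEBUG_TYPE "liveness-analysis"

    static void traceOp(mlir::Operation *op) {
      // Print the op with its regions elided, so large bodies do not
      // flood the debug log.
      LDBG() << "[visitOperation] Enter: "
             << mlir::OpWithFlags(op, mlir::OpPrintingFlags().skipRegions());
    }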
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 176d53e..16f7033 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -14,7 +14,7 @@
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "dataflow"
@@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
(void)inserted;
DATAFLOW_DEBUG({
if (inserted) {
- llvm::dbgs() << "Creating dependency between " << debugName << " of "
- << anchor << "\nand " << debugName << " on " << dependent
- << "\n";
+ LDBG() << "Creating dependency between " << debugName << " of " << anchor
+ << "\nand " << debugName << " on " << dependent;
}
});
}
@@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
- DATAFLOW_DEBUG(llvm::dbgs()
- << "Priming analysis: " << analysis.debugName << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
}
@@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
auto [point, analysis] = worklist.front();
worklist.pop();
- DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
- << "' on: " << point << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName
+ << "' on: " << point);
if (failed(analysis->visit(point)))
return failure();
}
@@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
assert(isRunning &&
"DataFlowSolver is not running, should not use propagateIfChanged");
if (changed == ChangeResult::Change) {
- DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->anchor << "\n"
- << "Value: " << *state << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName
+ << " of " << state->anchor << "\n"
+ << "Value: " << *state);
state->onUpdate(this);
}
}
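DataFlowFramework.cpp keeps its DATAFLOW_DEBUG wrapper and only modernizes the body to LDBG(). Assuming the usual definition, guarded by LLVM_ENABLE_ABI_BREAKING_CHECKS from llvm/Config/abi-breaking.h (included above), the shape is roughly:

    #include "llvm/Config/abi-breaking.h"
    #include "llvm/Support/Debug.h"
    #include "llvm/Support/DebugLog.h"

    #define DEBUG_TYPE "dataflow"

    // Assumed wrapper: solver tracing is compiled out entirely unless
    // ABI-breaking checks are enabled.
    #if LLVM_ENABLE_ABI_BREAKING_CHECKS
    #define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
    #else
    #define DATAFLOW_DEBUG(X)
    #endif

    static void trace(const char *name) {
      DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << name);
    }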
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 239ffe6..ea7dfdc 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -15,7 +15,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <functional>
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 9f4a87a..8b14e71 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
nestedPunctuation.pop_back();
return success();
};
+ const char *curBufferEnd = state.lex.getBufferEnd();
do {
// Handle code completions, which may appear in the middle of the symbol
// body.
@@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
break;
}
+ if (curBufferEnd == curPtr) {
+ if (!nestedPunctuation.empty())
+ return emitPunctError();
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ }
+
char c = *curPtr++;
switch (c) {
case '\0':
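The new curBufferEnd check makes the pretty-dialect-name scanner stop at the end of the buffer before dereferencing, rather than relying on a trailing nul byte; the matching lexer changes appear below. A minimal sketch of the guard, with hypothetical names:

    // `cur` and `end` stand in for curPtr and state.lex.getBufferEnd().
    static bool scanBody(const char *cur, const char *end) {
      while (true) {
        if (cur == end)
          return false; // unexpected EOF in pretty dialect name
        char c = *cur++;
        if (c == '>')
          return true; // found closing punctuation
        // ... other characters handled here ...
      }
    }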
diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp
index 751bd63..8f53529 100644
--- a/mlir/lib/AsmParser/Lexer.cpp
+++ b/mlir/lib/AsmParser/Lexer.cpp
@@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
AsmParserCodeCompleteContext *codeCompleteContext)
: sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
auto bufferID = sourceMgr.getMainFileID();
+
+ // Check to see if the main buffer contains the last buffer, and if so, use
+ // the last buffer as the main file for parsing.
+ if (sourceMgr.getNumBuffers() > 1) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ if (main->getBufferStart() <= last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ bufferID = lastFileID;
+ }
+ }
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
@@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) {
}
Token Lexer::lexToken() {
+ const char *curBufferEnd = curBuffer.end();
while (true) {
const char *tokStart = curPtr;
@@ -78,6 +91,9 @@ Token Lexer::lexToken() {
if (tokStart == codeCompleteLoc)
return formToken(Token::code_complete, tokStart);
+ if (tokStart == curBufferEnd)
+ return formToken(Token::eof, tokStart);
+
// Lex the next token.
switch (*curPtr++) {
default:
@@ -102,7 +118,7 @@ Token Lexer::lexToken() {
case 0:
// This may either be a nul character in the source file or may be the EOF
// marker that llvm::MemoryBuffer guarantees will be there.
- if (curPtr - 1 == curBuffer.end())
+ if (curPtr - 1 == curBufferEnd)
return formToken(Token::eof, tokStart);
continue;
@@ -259,7 +275,11 @@ void Lexer::skipComment() {
assert(*curPtr == '/');
++curPtr;
+ const char *curBufferEnd = curBuffer.end();
while (true) {
+ if (curPtr == curBufferEnd)
+ return;
+
switch (*curPtr++) {
case '\n':
case '\r':
@@ -267,7 +287,7 @@ void Lexer::skipComment() {
return;
case 0:
// If this is the end of the buffer, end the comment.
- if (curPtr - 1 == curBuffer.end()) {
+ if (curPtr - 1 == curBufferEnd) {
--curPtr;
return;
}
@@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
Token Lexer::lexString(const char *tokStart) {
assert(curPtr[-1] == '"');
+ const char *curBufferEnd = curBuffer.end();
while (true) {
// Check to see if there is a code completion location within the string. In
// these cases we generate a completion location and place the currently
@@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) {
case 0:
// If this is a random nul character in the middle of a string, just
// include it. If it is the end of file, then it is an error.
- if (curPtr - 1 != curBuffer.end())
+ if (curPtr - 1 != curBufferEnd)
continue;
[[fallthrough]];
case '\n':
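Each lexer loop (token dispatch, comments, strings) now hoists curBuffer.end() into a local and compares against it before reading, so a sub-buffer of a larger file, which has no nul terminator of its own and is exactly what the new constructor logic above selects, still terminates cleanly. A condensed sketch of the bounded-scan idiom:

    #include "llvm/ADT/StringRef.h"
    #include <cstddef>

    // Count non-space characters without reading past the buffer end;
    // safe even when the buffer lacks a trailing nul.
    static std::size_t countNonSpace(llvm::StringRef buffer) {
      const char *cur = buffer.begin();
      const char *end = buffer.end(); // hoisted once, as in the patch
      std::size_t n = 0;
      while (cur != end) {
        if (*cur++ != ' ')
          ++n;
      }
      return n;
    }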
diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h
index 4085a9b..670444e 100644
--- a/mlir/lib/AsmParser/Lexer.h
+++ b/mlir/lib/AsmParser/Lexer.h
@@ -40,6 +40,9 @@ public:
/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }
+ /// Returns the end of the buffer.
+ const char *getBufferEnd() { return curBuffer.end(); }
+
/// Return the code completion location of the lexer, or nullptr if there is
/// none.
const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec..a461ebe 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -19,7 +19,6 @@
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstdint>
#include <limits>
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 8f79caf..db84ee1 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -16,8 +16,8 @@
#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
@@ -1428,6 +1428,12 @@ public:
}
};
+// Check if the Python version is less than 3.13. Py_IsFinalizing has been part
+// of the stable ABI since 3.13; before that it was available as _Py_IsFinalizing.
+#if PY_VERSION_HEX < 0x030d0000
+#define Py_IsFinalizing _Py_IsFinalizing
+#endif
+
class PyDenseResourceElementsAttribute
: public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
public:
@@ -1474,8 +1480,9 @@ public:
// The userData is a Py_buffer* that the deleter owns.
auto deleter = [](void *userData, const void *data, size_t size,
size_t align) {
- if (!Py_IsInitialized())
- Py_Initialize();
+ if (Py_IsFinalizing())
+ return;
+ assert(Py_IsInitialized() && "expected interpreter to be initialized");
Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
nb::gil_scoped_acquire gil;
PyBuffer_Release(ownedView);
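
Condensed sketch of the shim and the new deleter logic in isolation, assuming a CPython embedding context with <Python.h> on the include path; releaseBuffer is a hypothetical stand-in for the lambda above, and it uses the raw GIL API where the binding uses nb::gil_scoped_acquire:

#include <Python.h>

// Py_IsFinalizing joined the stable ABI in 3.13; older headers only offer
// the underscore-prefixed spelling.
#if PY_VERSION_HEX < 0x030d0000
#define Py_IsFinalizing _Py_IsFinalizing
#endif

void releaseBuffer(Py_buffer *ownedView) {
  // During interpreter shutdown, taking the GIL from a foreign thread can
  // deadlock or crash; leaking the view is the safer failure mode.
  if (Py_IsFinalizing())
    return;
  PyGILState_STATE state = PyGILState_Ensure();
  PyBuffer_Release(ownedView);
  PyGILState_Release(state);
}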
diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
index 8b9a395..ccda668 100644
--- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
+++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
@@ -1,19 +1,16 @@
# Dialect registration.
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything
RegisterEverything.cpp
LINK_LIBS PUBLIC
- ${dialect_libs}
${translation_libs}
- ${conversion_libs}
- ${extension_libs}
MLIRBuiltinToLLVMIRTranslation
MLIRCAPIIR
- MLIRLLVMToLLVMIRTranslation
MLIRCAPITransforms
+ MLIRLLVMToLLVMIRTranslation
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
)
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index d25c84a..191b5ab6 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -20,3 +20,37 @@ add_subdirectory(Target)
add_subdirectory(Tools)
add_subdirectory(Transforms)
add_subdirectory(ExecutionEngine)
+
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+
+add_mlir_library(MLIRRegisterAllDialects
+ RegisterAllDialects.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllPasses
+ RegisterAllPasses.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs} # Some passes are part of the dialect libs
+ ${conversion_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllExtensions
+ RegisterAllExtensions.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ ${conversion_libs}
+ ${extension_libs}
+ )
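
Consumer-side view of the split: a tool can now link the three MLIRRegisterAll* libraries instead of carrying ${dialect_libs}, ${conversion_libs}, and ${extension_libs} in its own LINK_LIBS. A sketch using the existing upstream registration entry points; which library backs which call is inferred from the source file names above, not verified here:

#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"

int main() {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);   // presumably MLIRRegisterAllDialects
  mlir::registerAllExtensions(registry); // presumably MLIRRegisterAllExtensions
  mlir::registerAllPasses();             // presumably MLIRRegisterAllPasses
  mlir::MLIRContext context(registry);
  return 0;
}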
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index bc0d9bf..64720bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -232,8 +232,8 @@ struct FatRawBufferCastLowering
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
- result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
- kAllocatedPtrPosInMemRefDescriptor);
+ SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
kAlignedPtrPosInMemRefDescriptor);
result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
@@ -481,16 +481,16 @@ struct MemoryCounterWaitOpLowering
if (chipset.majorVersion >= 12) {
Location loc = op.getLoc();
if (std::optional<int> ds = adaptor.getDs())
- rewriter.create<ROCDL::WaitDscntOp>(loc, *ds);
+ ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
if (std::optional<int> load = adaptor.getLoad())
- rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load);
+ ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
if (std::optional<int> store = adaptor.getStore())
- rewriter.create<ROCDL::WaitStorecntOp>(loc, *store);
+ ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
if (std::optional<int> exp = adaptor.getExp())
- rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp);
+ ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
rewriter.eraseOp(op);
return success();
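
This file carries the first instances of a migration repeated throughout the diff: the member template rewriter.create<OpTy>(loc, ...) becomes the static factory OpTy::create(rewriter, loc, ...). Besides reading uniformly, the static form avoids the awkward `rewriter.template create<OpTy>` spelling inside templated patterns (visible in the ArithToEmitC and SPIRVToLLVM hunks below). A self-contained toy model of the two spellings; Builder and WaitOp are illustrative stand-ins, not MLIR types:

#include <utility>

struct Builder {
  template <typename OpTy, typename... Args>
  OpTy create(Args &&...args) {
    return OpTy(std::forward<Args>(args)...);
  }
};

struct WaitOp {
  int count;
  explicit WaitOp(int c) : count(c) {}
  // New style: static factory taking the builder as the first argument.
  static WaitOp create(Builder &b, int c) { return b.create<WaitOp>(c); }
};

int main() {
  Builder b;
  WaitOp oldStyle = b.create<WaitOp>(3);  // old spelling
  WaitOp newStyle = WaitOp::create(b, 3); // new spelling
  return oldStyle.count - newStyle.count; // 0
}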
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57..8230591 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,7 @@ LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opOutWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleType = getElementTypeOrSelf(scale);
Type outType = getElementTypeOrSelf(out);
+ int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
@@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
else if (scaleType.getIntOrFloatBitWidth() > 32)
scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- VectorType extScaleResultType = VectorType::get(opWidth, outType);
+ VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+ if (origScaleVecType)
llvm::append_range(originalScaleShape, origScaleVecType.getShape());
originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+ for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = vector::ExtractStridedSliceOp::create(
- rewriter, loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = amdgpu::ScaledExtPackedOp::create(
- rewriter, loc, extScaleResultType, slice, uniformScale, 0);
- if (sliceWidth != opWidth)
- scaleExt = vector::ExtractStridedSliceOp::create(
- rewriter, loc, scaleExt, 0, sliceWidth, 1);
- blockResult = vector::InsertStridedSliceOp::create(
- rewriter, loc, scaleExt, blockResult, i, 1);
+ i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+ Value inSlice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, inSliceWidth, 1);
+ for (int64_t j = 0,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+ j < inSliceWidth; j += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+ // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inSlice, uniformScale,
+ j / opOutWidth);
+ if (outSliceWidth < opOutWidth) {
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i + j, 1);
+ }
}
VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +565,7 @@ LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opInWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
if (outVecType && outVecType.isScalable())
return failure();
@@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Value zero = arith::ConstantOp::create(rewriter, loc, outType,
rewriter.getFloatAttr(outType, 0.0));
- unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
- VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+ int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+ VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
@@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
- SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
- llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+ SmallVector<int64_t> scaleShape;
+ if (origScaleVecType)
+ llvm::append_range(scaleShape, origScaleVecType.getShape());
- originalScaleShape.insert(originalScaleShape.end(),
- inShape.size() - originalScaleShape.size(), 1);
+ scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
- auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+ auto maybeRatio = computeShapeRatio(inShape, scaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");
@@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
- i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = vector::ExtractStridedSliceOp::create(
- rewriter, loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
- rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
- /*existing=*/nullptr);
- int64_t packedWidth =
- cast<VectorType>(scaleTrunc.getType()).getNumElements();
- if (packedWidth != opWidth)
+ for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+ i < blockSize; i += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+ Value scaleTrunc;
+ // Case where <= 2 elements are being truncated.
+ if (outSliceWidth <= opInWidth) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, outSliceWidth, 1);
+ // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+ /*existing=*/nullptr);
+ } else {
+ scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+ truncScaleResultType, zero);
+ for (int64_t j = 0,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+ j < outSliceWidth; j += opInWidth,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i + j, inSliceWidth, 1);
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale,
+ j / opInWidth, scaleTrunc);
+ }
+ }
+ if (outSliceWidth != opOutWidth) {
scaleTrunc = vector::ExtractStridedSliceOp::create(
- rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+ rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+ }
blockResult = vector::InsertStridedSliceOp::create(
rewriter, loc, scaleTrunc, blockResult, i, 1);
}
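
The rewrite splits the single opWidth constant into opInWidth (how many input elements share one packed 32-bit word) and opOutWidth (how many results one packed op yields), with a nested loop walking output slices inside each input slice. A worked example of the index arithmetic for the ext direction, assuming a 4-bit input type so opInWidth = 32/4 = 8, opOutWidth = 2, and a block of 12 elements:

#include <algorithm>
#include <cstdio>

int main() {
  const long opInWidth = 8, opOutWidth = 2, blockSize = 12;
  for (long i = 0, inW = std::min(opInWidth, blockSize - i); i < blockSize;
       i += inW, inW = std::min(opInWidth, blockSize - i)) {
    for (long j = 0, outW = std::min(opOutWidth, inW - j); j < inW;
         j += outW, outW = std::min(opOutWidth, inW - j)) {
      // j / opOutWidth is the word-select index passed to ScaledExtPackedOp;
      // i + j is the insertion offset into blockResult.
      std::printf("in slice [%ld,+%ld) -> out slice [%ld,+%ld), word %ld\n",
                  i, inW, i + j, outW, j / opOutWidth);
    }
  }
}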
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 59b3fe2..515fe5c 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -402,8 +402,8 @@ public:
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
// Actual cast (may change bitwidth)
- auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
- castDestType, actualOp);
+ auto cast =
+ emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
// Cast to the expected output type
auto result = adaptValueType(cast, rewriter, opReturnType);
@@ -507,8 +507,8 @@ public:
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
- Value arithmeticResult = rewriter.template create<EmitCOp>(
- op.getLoc(), arithmeticType, lhs, rhs);
+ Value arithmeticResult =
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -547,8 +547,8 @@ public:
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
- Value arithmeticResult = rewriter.template create<EmitCOp>(
- op.getLoc(), arithmeticType, lhs, rhs);
+ Value arithmeticResult =
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -748,8 +748,8 @@ public:
}
Value fpCastOperand = adaptor.getIn();
if (actualOperandType != operandType) {
- fpCastOperand = rewriter.template create<emitc::CastOp>(
- castOp.getLoc(), actualOperandType, fpCastOperand);
+ fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
+ actualOperandType, fpCastOperand);
}
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d43e681..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get an IntegerAttr from a FloatAttr while preserving the bits. Useful for
+// converting float constants to integer constants.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to an 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
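
getIntegerAttrFromFloatAttr reinterprets the constant's storage rather than converting its value, so an emulated 8-bit float constant becomes the i8 with the identical bit pattern. The same operation in standalone C++; f32/i32 stand in for the 8-bit case, which C++ lacks natively, and std::bit_cast requires C++20:

#include <bit>
#include <cstdint>
#include <cstdio>

int main() {
  float f = 1.5f;
  // Bit-preserving reinterpretation, the analogue of
  // FloatAttr::getValue().bitcastToAPInt().
  auto bits = std::bit_cast<std::uint32_t>(f);
  std::printf("0x%08x\n", bits); // 0x3fc00000
  return 0;
}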
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 1510b0b..e34b368 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 79e1683..29e6552 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -18,7 +18,6 @@
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 30a7170..3edcbb8 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
scf::YieldOp::create(rewriter, loc, acc);
};
- auto size = rewriter
- .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
- loopBody)
+ auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
+ ValueRange(one), loopBody)
.getResult(0);
MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index f84375b..785cb82 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
-add_subdirectory(MeshToMPI)
+add_subdirectory(ShardToMPI)
add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index eeff8a9..5ad514d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <type_traits>
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index c8311eb..5ac838c 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
return emitError(loc, "Cannot create unreachable terminator for '")
<< parentOp->getName() << "'";
- return builder
- .create<func::ReturnOp>(
- loc, llvm::map_to_vector(funcOp.getResultTypes(),
- [&](Type type) {
- return getUndefValue(loc, builder, type);
- }))
+ return func::ReturnOp::create(
+ builder, loc,
+ llvm::map_to_vector(
+ funcOp.getResultTypes(),
+ [&](Type type) { return getUndefValue(loc, builder, type); }))
.getOperation();
}
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4..56b6181 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp
index c9b1dc1..ee6d7d5 100644
--- a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp
@@ -9,8 +9,6 @@
#include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h"
#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 252245d..c70b5f0 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -9,7 +9,6 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
-#include "llvm/ADT/DenseSet.h"
using namespace mlir;
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f65..c0439a4 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 63eb6c58..3cfbd89 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -579,8 +579,8 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
auto function = [&] {
if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
return function;
- return OpBuilder::atBlockEnd(module.getBody())
- .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
+ auto builder = OpBuilder::atBlockEnd(module.getBody());
+ return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
}();
return LLVM::CallOp::create(builder, loc, function, arguments);
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index a19194e..1817861 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
- IntegerAttr widthAttr;
- if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
- widthAttr.getValue().getZExtValue() > subgroupSize)
+ unsigned width = rotateOp.getWidth();
+ if (width > subgroupSize)
return rewriter.notifyMatchFailure(
- rotateOp,
- "rotate width is not a constant or larger than target subgroup size");
+ rotateOp, "rotate width is larger than target subgroup size");
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value offsetVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
+ Value widthVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
- rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
- adaptor.getWidth());
+ rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
Value validVal;
- if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+ if (width == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
+ IntegerAttr widthAttr = adaptor.getWidthAttr();
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ laneId, widthVal);
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -559,8 +561,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
- return builder
- .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
+ return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
+ clusterSizeValue)
.getResult();
}
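
In the rotate lowering above, the width is now read directly as a static attribute (rotateOp.getWidth() returning unsigned), so the m_Constant match and its failure path disappear; what remains is the validity rule. Its logic in isolation, as a sketch rather than SPIR-V codegen:

#include <cstdio>

// A lane's rotate result is valid iff laneId < width; when width equals the
// subgroup size the comparison folds away to a constant true.
bool rotateResultValid(unsigned laneId, unsigned width, unsigned subgroupSize) {
  if (width == subgroupSize)
    return true;
  return laneId < width;
}

int main() {
  // subgroupSize = 32, width = 16: lanes 0..15 valid, 16..31 not.
  std::printf("%d %d\n", rotateResultValid(7, 16, 32),
              rotateResultValid(20, 16, 32)); // 1 0
}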
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index ecd5b63..2568044 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
- toDynamic
- ? builder
- .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
+ toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5b68eb8..e5496e5 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+ ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
}
return ret;
}
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 08a4566..cde2340 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -17,13 +17,12 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -33,7 +32,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
@@ -654,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
- LLVM_DEBUG({
- DBGS() << "non-integer element type for CtlzFunc; type was: ";
- elementType.print(llvm::dbgs());
- });
+ LDBG() << "non-integer element type for CtlzFunc; type was: "
+ << elementType;
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
@@ -699,7 +695,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
scf::IfOp ifOp =
scf::IfOp::create(builder, elementType, inputEqZero,
/*addThenBlock=*/true, /*addElseBlock=*/true);
- ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
+ auto thenBuilder = ifOp.getThenBodyBuilder();
+ scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
auto elseBuilder =
ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
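
This file also begins the logging migration that recurs in MemRefToLLVM, NVGPUToNVVM, and NVVMToLLVM below: the per-file DBGS() macro and LLVM_DEBUG({...}) blocks give way to LDBG() from llvm/Support/DebugLog.h. As the hunks rely on, LDBG() tags the message and terminates the line itself, which is why the "[DEBUG_TYPE]: " prefixes and trailing "\n"s are dropped. A minimal sketch:

#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "math-to-funcs"

static void logWidth(int bitWidth) {
  // Before: LLVM_DEBUG(DBGS() << "bit width: " << bitWidth << "\n");
  LDBG() << "bit width: " << bitWidth; // tag and newline are implicit
}

int main() {
  logWidth(32); // prints only under -debug-only=math-to-funcs
  return 0;
}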
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 93d8b49..df219f3 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +22,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
@@ -31,7 +31,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-rocdl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e882845..6bd0e2d 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
using namespace mlir;
+static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
+ return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() != 0 &&
+ !llvm::is_contained(memRefType.getShape(), 0);
+}
+
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = allocOp.getLoc();
+ MemRefType memrefType = allocOp.getType();
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ Type elementType = memrefType.getElementType();
+ IndexType indexType = rewriter.getIndexType();
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
+
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ Value numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+
+ emitc::CallOpaqueOp allocCall;
+ StringAttr allocFunctionName;
+ Value alignmentValue;
+ SmallVector<Value, 2> argsVec;
+ if (allocOp.getAlignment()) {
+ allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
+ alignmentValue = rewriter.create<emitc::ConstantOp>(
+ loc, sizeTType,
+ rewriter.getIntegerAttr(indexType,
+ allocOp.getAlignment().value_or(0)));
+ argsVec.push_back(alignmentValue);
+ } else {
+ allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
+ }
+
+ argsVec.push_back(totalSizeBytes);
+ ValueRange args(argsVec);
+
+ allocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ allocFunctionName, args);
+
+ emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
+ emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
+ loc, targetPointerType, allocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
- if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
- llvm::is_contained(memRefType.getShape(), 0)) {
+ if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
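
Reconstructed shape of the C the new ConvertAlloc pattern emits for a static memref such as memref<128x4xf32>, based on the pattern body above (names illustrative): a sizeof-based byte count, a call to aligned_alloc(alignment, size) when the op carries an alignment attribute and malloc(size) otherwise, then a cast from void* to the element pointer type.

#include <cstdlib>

int main() {
  const unsigned long numElements = 128 * 4;
  // With an alignment attribute: aligned_alloc(alignment, totalSizeBytes).
  float *aligned = (float *)aligned_alloc(64, sizeof(float) * numElements);
  // Without one: plain malloc(totalSizeBytes).
  float *plain = (float *)malloc(sizeof(float) * numElements);
  free(aligned);
  free(plain);
  return 0;
}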
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace
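
The new post-conversion walk makes the pass output self-contained: if any emitted call targets mallocFunctionName or alignedAllocFunctionName, one standard include is prepended to the module unless an equivalent one is already present, with the header chosen by the lowerToCpp option. Assuming cppStandardLibraryHeader and cStandardLibraryHeader name the usual cstdlib/stdlib.h pair (inferred from the identifiers, not verified), the emitted translation unit gains a header like:

#include <cstdlib> // lowerToCpp = true; <stdlib.h> otherwise

void *scratch() { return malloc(64); } // the call that triggered the include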
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 53a1912..dc2035b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -24,11 +24,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
+
#include <optional>
#define DEBUG_TYPE "memref-to-llvm"
-#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
@@ -575,8 +576,8 @@ private:
Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
getTypeConverter()->getIndexType(),
offsetPtr, idxPlusOne);
- return rewriter
- .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+ return LLVM::LoadOp::create(rewriter, loc,
+ getTypeConverter()->getIndexType(), sizePtr)
.getResult();
}
@@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
- "from fmax to fmaximum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs";
return LLVM::AtomicBinOp::fmaximum;
case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
@@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
- "from fmin to fminimum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs";
return LLVM::AtomicBinOp::fminimum;
case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 905287e1..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -21,19 +21,17 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1106,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1282,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1367,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1392,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1401,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1468,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 662ee9e..91788f9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -25,11 +25,10 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
- LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+ LDBG() << "Ptx Builder does not lower \n\t" << op;
return failure();
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
- LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
+ LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
+ LDBG() << asmValue << "\t Modifier : " << &modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 3e434ea..5bd1d49 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
predList.emplace_back(pos, builder.getIsNotNull());
- if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
+ if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
index e1a9fa59..2d9c661f 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
@@ -14,9 +14,7 @@
#include "RootOrdering.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
-#include <queue>
#include <utility>
using namespace mlir;
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 240491a..807be7e 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
+ SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
@@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
- rewriter.replaceOp(whileOp, condOp.getArgs());
+ rewriter.replaceOp(whileOp, args);
return success();
}
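
The one-line fix above is a lifetime bug: replaceOpWithNewOp erases condOp, so the later rewriter.replaceOp(whileOp, condOp.getArgs()) read the operands of a freed op. Copying the range into a SmallVector first decouples the two. The hazard, reduced to standard C++ with illustrative containers in place of MLIR ops:

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> operands = {"%cond", "%arg0"};
  // Fix: materialize a copy before the source container is invalidated.
  std::vector<std::string> args(operands.begin(), operands.end());
  operands.clear(); // stands in for erasing condOp
  assert(args.size() == 2); // the copy survives the erasure
  return 0;
}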
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index f191f35..badd2f6 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -25,9 +25,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
#include <optional>
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index aae3271..9b61540 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1493,11 +1493,11 @@ public:
Value extended;
if (op2TypeWidth < dstTypeWidth) {
if (isUnsignedIntegerOrVector(op2Type)) {
- extended = rewriter.template create<LLVM::ZExtOp>(
- loc, dstType, adaptor.getOperand2());
+ extended =
+ LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
} else {
- extended = rewriter.template create<LLVM::SExtOp>(
- loc, dstType, adaptor.getOperand2());
+ extended =
+ LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
}
} else if (op2TypeWidth == dstTypeWidth) {
extended = adaptor.getOperand2();
@@ -1505,8 +1505,8 @@ public:
return failure();
}
- Value result = rewriter.template create<LLVMOp>(
- loc, dstType, adaptor.getOperand1(), extended);
+ Value result =
+ LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 7025c5a..0ff9fb3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
index 15560aa..564f36f 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
@@ -1,8 +1,8 @@
-add_mlir_conversion_library(MLIRMeshToMPI
- MeshToMPI.cpp
+add_mlir_conversion_library(MLIRShardToMPI
+ ShardToMPI.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI
DEPENDS
MLIRConversionPassIncGen
@@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
MLIRLinalgTransforms
MLIRMemRefDialect
MLIRPass
- MLIRMeshDialect
+ MLIRShardDialect
MLIRMPIDialect
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 63b1fda..fa9e544 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -1,4 +1,4 @@
-//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements a translation of Mesh communication ops tp MPI ops.
+// This file implements a translation of Shard communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -20,11 +20,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
@@ -35,16 +35,15 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "mesh-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DEBUG_TYPE "shard-to-mpi"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
-using namespace mesh;
+using namespace shard;
namespace {
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
@@ -177,9 +176,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
auto type = RankedTensorType::get({nSplits, 2}, i64);
Value resHaloSizes =
haloSizes.empty()
- ? rewriter
- .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
- i64)
+ ? tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64)
.getResult()
: tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
.getResult();
@@ -188,18 +186,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
// maxSplitSize+1}. Store the offsets in the tensor but set trailing
// elements for smaller split-groups to -1. Computing the max size of the
// split groups requires using collectiveProcessGroupSize (which needs the
- // MeshOp)
+ // GridOp)
Value resOffsets;
if (adaptor.getStaticShardedDimsOffsets().empty()) {
resOffsets = tensor::EmptyOp::create(rewriter, loc,
std::array<int64_t, 2>{0, 0}, i64);
} else {
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto gridOp = getGrid(op, symbolTableCollection);
int64_t maxSplitSize = 0;
for (auto axes : splitAxes) {
int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
assert(splitSize != ShapedType::kDynamic);
maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
}
@@ -218,7 +216,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
int64_t curr = 0;
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
@@ -264,20 +262,20 @@ struct ConvertProcessMultiIndexOp
SymbolTableCollection symbolTableCollection;
Location loc = op.getLoc();
- auto meshOp = getMesh(op, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto gridOp = getGrid(op, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return failure();
SmallVector<Value> dims;
llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
- Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp);
+ Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
- // optionally extract subset of mesh axes
+ // optionally extract subset of grid axes
auto axes = adaptor.getAxes();
if (!axes.empty()) {
SmallVector<Value> subIndex;
@@ -306,13 +304,11 @@ public:
auto ctx = op.getContext();
Value commWorld =
mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
- auto rank =
- rewriter
- .create<mpi::CommRankOp>(
- loc,
- TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
- commWorld)
- .getRank();
+ auto rank = mpi::CommRankOp::create(
+ rewriter, loc,
+ TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+ commWorld)
+ .getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
return success();
@@ -338,12 +334,12 @@ struct ConvertNeighborsLinearIndicesOp
Location loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto gridOp = getGrid(op, symbolTableCollection);
auto mIdx = adaptor.getDevice();
auto orgIdx = mIdx[axes[0]];
SmallVector<Value> dims;
llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
Value dimSz = dims[axes[0]];
@@ -394,14 +390,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
if (!sharding) {
return op->emitError()
- << "Expected SharingOp as defining op for sharding"
+ << "Expected ShardingOp as defining op for sharding"
<< " but found " << adaptor.getSharding()[0].getDefiningOp();
}
// Compute the sharded shape by applying the sharding to the input shape.
// If shardedDimsOffsets is not defined in the sharding, the shard shape is
// computed by dividing the dimension size by the number of shards in that
- // dimension (which is given by the size of the mesh axes provided in
+ // dimension (which is given by the size of the grid axes provided in
// split-axes). Odd elements get distributed to trailing shards. If a
// shardedDimsOffsets is provided, the shard shape is computed by
// subtracting the offset of the current shard from the offset of the next
@@ -431,11 +427,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> multiIdx =
getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
- // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+ // Get the GridOp; the grid shape is needed to compute the sharded shape.
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(sharding, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto gridOp = getGrid(sharding, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return failure();
auto splitAxes = sharding.getSplitAxes().getAxes();
@@ -455,7 +451,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
tmp);
}
- // With static mesh shape the sizes of the split axes are known.
+ // With static grid shape the sizes of the split axes are known.
// Hence the start/pos for each split axes in shardDimsOffsets can be
// computed statically.
int64_t pos = 0;
@@ -475,10 +471,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
// Create a value from the static position in shardDimsOffsets.
Value posVal = arith::ConstantOp::create(rewriter, loc,
rewriter.getIndexAttr(pos));
- // Get the index of the local shard in the mesh axis.
+ // Get the index of the local shard in the grid axis.
Value idx = multiIdx[axes[0]];
auto numShards =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
if (shardedDimsOffs) {
// If sharded dims offsets are provided, use them to compute the
// sharded shape.
@@ -556,13 +552,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SymbolTableCollection symbolTableCollection;
- auto mesh = adaptor.getMesh();
- mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
- if (!meshOp)
- return op->emitError() << "No mesh found for AllReduceOp";
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto grid = adaptor.getGrid();
+ mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
+ if (!gridOp)
+ return op->emitError() << "No grid found for AllReduceOp";
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return op->emitError()
- << "Dynamic mesh shape not supported in AllReduceOp";
+ << "Dynamic grid shape not supported in AllReduceOp";
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = adaptor.getInput();
@@ -592,27 +588,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
linalg::CopyOp::create(iBuilder, input, buffer);
// Get an MPI_Comm_split for the AllReduce operation.
- // The color is the linear index of the process in the mesh along the
- // non-reduced axes. The key is the linear index of the process in the mesh
+ // The color is the linear index of the process in the grid along the
+ // non-reduced axes. The key is the linear index of the process in the grid
// along the reduced axes.
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
iBuilder.getIndexType());
SmallVector<Value> myMultiIndex =
- ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
.getResult();
Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
- auto redAxes = adaptor.getMeshAxes();
+ auto redAxes = adaptor.getGridAxes();
for (auto axis : redAxes) {
multiKey[axis] = myMultiIndex[axis];
myMultiIndex[axis] = zero;
}
Value color =
- createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+ createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
- Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+ Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
@@ -698,15 +694,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
- auto mesh = adaptor.getMesh();
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto grid = adaptor.getGrid();
+ auto gridOp = getGrid(op, symbolTableCollection);
// subviews need Index values
for (auto &sz : haloSizes) {
if (auto value = dyn_cast<Value>(sz))
- sz =
- rewriter
- .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
- .getResult();
+ sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+ value)
+ .getResult();
}
// most of the offset/size/stride data is the same for all dims
@@ -745,10 +740,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
rewriter.getIndexType());
auto myMultiIndex =
- ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
.getResult();
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
@@ -758,9 +753,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
- auto tmp = rewriter
- .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
- splitAxes)
+ auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+ myMultiIndex, splitAxes)
.getResults();
// MPI operates on i32...
Value neighbourIDs[2] = {
@@ -791,7 +785,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
: haloSizes[currHaloDim * 2];
// Check if we need to send and/or receive
- // Processes on the mesh borders have only one neighbor
+ // Processes on the grid borders have only one neighbor
auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
auto hasFrom = arith::CmpIOp::create(
@@ -869,8 +863,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
};
-struct ConvertMeshToMPIPass
- : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+struct ConvertShardToMPIPass
+ : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
using Base::Base;
/// Run the dialect converter on the module.
@@ -879,12 +873,12 @@ struct ConvertMeshToMPIPass
RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
- // Define a type converter to convert mesh::ShardingType,
+ // Define a type converter to convert shard::ShardingType,
// mostly for use in return operations.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
- // convert mesh::ShardingType to a tuple of RankedTensorTypes
+ // convert shard::ShardingType to a tuple of RankedTensorTypes
typeConverter.addConversion(
[](ShardingType type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -920,10 +914,10 @@ struct ConvertMeshToMPIPass
return results;
});
- // No mesh dialect should left after conversion...
- target.addIllegalDialect<mesh::MeshDialect>();
- // ...except the global MeshOp. MeshShapeOp which will get folded later.
- target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
+ // No shard dialect should be left after conversion...
+ target.addIllegalDialect<shard::ShardDialect>();
+ // ...except the global GridOp and GridShapeOp, which will get folded later.
+ target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
// Allow all the stuff that our patterns will convert to
target.addLegalDialect<
BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
@@ -951,7 +945,7 @@ struct ConvertMeshToMPIPass
// Folding patterns cannot be mixed with conversion patterns -> extra pass.
patterns.clear();
SymbolTableCollection symbolTableCollection;
- mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
+ mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
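
A large share of the hunks above and below are a mechanical migration from the member-template builder rewriter.create<OpTy>(loc, ...) to the static OpTy::create(rewriter, loc, ...) form. A minimal sketch of the two equivalent spellings, assuming the current MLIR OpBuilder API and an arbitrary arith op:

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    // Both calls build the same arith.addi; the patch only moves the builder
    // into the first argument of a static create() method.
    static mlir::Value addBothWays(mlir::OpBuilder &rewriter, mlir::Location loc,
                                   mlir::Value lhs, mlir::Value rhs) {
      mlir::Value oldStyle = rewriter.create<mlir::arith::AddIOp>(loc, lhs, rhs);
      mlir::Value newStyle = mlir::arith::AddIOp::create(rewriter, loc, lhs, rhs);
      (void)oldStyle;
      return newStyle;
    }
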
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386e..8cd650e 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ec55091..0e3de06 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -22,7 +22,6 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -570,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// to UIToFP.
if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
auto unrealizedCast =
- rewriter
- .create<UnrealizedConversionCastOp>(
- loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
- args[0])
+ UnrealizedConversionCastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
.getResult(0);
return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
unrealizedCast);
@@ -869,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Emit 'linalg.generic' op
auto resultTensor =
- opBuilder
- .create<linalg::GenericOp>(
- loc, outputTensor.getType(), operand, outputTensor, affineMaps,
- getNParallelLoopsAttrs(rank),
- [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
- // Emit 'linalg.yield' op
- linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
- })
+ linalg::GenericOp::create(
+ opBuilder, loc, outputTensor.getType(), operand, outputTensor,
+ affineMaps, getNParallelLoopsAttrs(rank),
+ [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+ // Emit 'linalg.yield' op
+ linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
+ })
.getResult(0);
// Cast to original operand type if necessary
@@ -1156,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
inputs.push_back(input);
// First fill the output buffer with the init value.
- auto emptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
- dynDims)
- .getResult();
+ auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
+ .getResult();
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
if (!fillValueAttr)
@@ -1168,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
op, "No initial value found for reduction operation");
auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
- auto filledTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValue},
- ValueRange{emptyTensor})
- .result();
+ auto filledTensor =
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+ ValueRange{emptyTensor})
+ .result();
outputs.push_back(filledTensor);
bool isNanIgnoreMode = false;
@@ -1187,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto trueAttr = rewriter.getBoolAttr(true);
auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
auto emptyBoolTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
- dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ trueValue.getType(), dynDims)
.getResult();
auto allResultsNaNTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{trueValue},
- ValueRange{emptyBoolTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
+ ValueRange{emptyBoolTensor})
.result();
// Note that because the linalg::ReduceOp has two variadic arguments
// (inputs and outputs) and it has the SameVariadicOperandSize trait we
@@ -1262,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
.getResult();
auto nanFilledTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{nanValue},
- ValueRange{emptyNanTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
+ ValueRange{emptyNanTensor})
.result();
// Create an empty tensor, no need to fill it since it will be
// overwritten by the select.
auto finalEmptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
.getResult();
// Do a selection between the tensors akin to:
@@ -1504,12 +1494,11 @@ public:
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.isUnsignedInteger()) {
- value = nestedBuilder
- .create<UnrealizedConversionCastOp>(
- nestedLoc,
- nestedBuilder.getIntegerType(
- valueTy.getIntOrFloatBitWidth()),
- value)
+ value = UnrealizedConversionCastOp::create(
+ nestedBuilder, nestedLoc,
+ nestedBuilder.getIntegerType(
+ valueTy.getIntOrFloatBitWidth()),
+ value)
.getResult(0);
}
if (valueTy.getIntOrFloatBitWidth() < 32) {
@@ -1558,9 +1547,8 @@ public:
}
if (outIntType.isUnsignedInteger()) {
- value = nestedBuilder
- .create<UnrealizedConversionCastOp>(nestedLoc,
- outIntType, value)
+ value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
+ outIntType, value)
.getResult(0);
}
linalg::YieldOp::create(nestedBuilder, loc, value);
@@ -2096,10 +2084,9 @@ public:
Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
// First fill the output buffer with the init value.
- auto emptyTensor = rewriter
- .create<tensor::EmptyOp>(loc, inputTy.getShape(),
- inputTy.getElementType(),
- ArrayRef<Value>({dynDims}))
+ auto emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, inputTy.getShape(),
+ inputTy.getElementType(), ArrayRef<Value>({dynDims}))
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@@ -2242,23 +2229,22 @@ public:
}
// First fill the output buffer for the index.
- auto emptyTensorIdx = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- outElementTy, dynDims)
- .getResult();
+ auto emptyTensorIdx =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ outElementTy, dynDims)
+ .getResult();
auto fillValueIdx = arith::ConstantOp::create(
rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
auto filledTensorIdx =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
- ValueRange{emptyTensorIdx})
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
+ ValueRange{emptyTensorIdx})
.result();
// Second fill the output buffer for the running max.
- auto emptyTensorMax = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- inElementTy, dynDims)
- .getResult();
+ auto emptyTensorMax =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
+ dynDims)
+ .getResult();
auto fillValueMaxAttr =
createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2269,9 +2255,8 @@ public:
auto fillValueMax =
arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
auto filledTensorMax =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
- ValueRange{emptyTensorMax})
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
+ ValueRange{emptyTensorMax})
.result();
// We need to reduce along the arg-max axis, with parallel operations along
@@ -2372,9 +2357,8 @@ public:
auto loc = op.getLoc();
auto emptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
- dynamicDims)
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultElementTy, dynamicDims)
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
@@ -2449,10 +2433,10 @@ public:
}
}
- auto emptyTensor = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- resultElementTy, dynDims)
- .getResult();
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultElementTy, dynDims)
+ .getResult();
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank()),
@@ -2586,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
- auto filledTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValue},
- ValueRange{emptyTensor})
- .result();
+ auto filledTensor =
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+ ValueRange{emptyTensor})
+ .result();
return filledTensor;
}
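
The TOSA reduction lowerings above repeatedly use the same empty-then-fill idiom to materialize an initialized accumulator. A hedged sketch of that idiom in the new spelling (the function name and the zero initializer are illustrative):

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/Dialect/Tensor/IR/Tensor.h"

    using namespace mlir;

    // Materialize a `shape` x `elemTy` tensor filled with zeros; `dynDims`
    // carries the sizes of any dynamic dimensions, as in the patterns above.
    static Value createZeroFilledTensor(OpBuilder &rewriter, Location loc,
                                        ArrayRef<int64_t> shape, Type elemTy,
                                        ValueRange dynDims) {
      Value empty =
          tensor::EmptyOp::create(rewriter, loc, shape, elemTy, dynDims)
              .getResult();
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getZeroAttr(elemTy));
      return linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                    ValueRange{empty})
          .result();
    }
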
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3a20524..da1fb20 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value conv, Value result,
ArrayRef<AffineMap> indexingMaps) {
ShapedType resultTy = cast<ShapedType>(conv.getType());
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
- }
- Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
- linalg::YieldOp::create(builder, loc, added);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, conv}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ arith::ExtSIOp::create(builder, loc, resType, biasVal);
+ }
+ Value added =
+ arith::AddIOp::create(builder, loc, biasVal, args[1]);
+ linalg::YieldOp::create(builder, loc, added);
+ })
.getResult(0);
}
@@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
// Build the broadcast-like operation as a linalg.generic.
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({source}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal =
- resultTy.getElementType().isFloat()
- ? arith::ExtFOp::create(builder, loc, resType, biasVal)
- .getResult()
- : arith::ExtSIOp::create(builder, loc, resType, biasVal)
- .getResult();
- }
- linalg::YieldOp::create(builder, loc, biasVal);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({source}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ resultTy.getElementType().isFloat()
+ ? arith::ExtFOp::create(builder, loc, resType, biasVal)
+ .getResult()
+ : arith::ExtSIOp::create(builder, loc, resType,
+ biasVal)
+ .getResult();
+ }
+ linalg::YieldOp::create(builder, loc, biasVal);
+ })
.getResult(0);
}
@@ -397,21 +398,19 @@ public:
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
- Value conv =
- rewriter
- .create<LinalgConvQOp>(
- loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ Value conv = LinalgConvQOp::create(
+ rewriter, loc, resultTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ ->getResult(0);
rewriter.replaceOp(op, conv);
return success();
}
- Value conv = rewriter
- .create<LinalgConvOp>(
- loc, accTy, ValueRange{input, weight},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ Value conv = LinalgConvOp::create(
+ rewriter, loc, accTy, ValueRange{input, weight},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -529,9 +528,8 @@ public:
Value emptyTensor = tensor::EmptyOp::create(
rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
Value biasEmptyTensor = tensor::EmptyOp::create(
@@ -544,10 +542,9 @@ public:
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
if (hasNullZps) {
- Value conv = rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
- loc, linalgConvTy, ValueRange{input, weight},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+ rewriter, loc, linalgConvTy, ValueRange{input, weight},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -565,22 +562,20 @@ public:
rewriter, loc, resultTy, conv, reassociationMap);
Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, convReshape}),
- biasEmptyTensor, indexingMaps,
- getNParallelLoopsAttrs(resultRank),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added;
- if (llvm::isa<FloatType>(inputETy))
- added = arith::AddFOp::create(nestedBuilder, loc, args[0],
- args[1]);
- else
- added = arith::AddIOp::create(nestedBuilder, loc, args[0],
- args[1]);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
- })
+ linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+ biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
+ ValueRange args) {
+ Value added;
+ if (llvm::isa<FloatType>(inputETy))
+ added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ else
+ added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
+ })
.getResult(0);
rewriter.replaceOp(op, result);
} else {
@@ -588,12 +583,11 @@ public:
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
- Value conv =
- rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
- loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
- .getResult(0);
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+ rewriter, loc, linalgConvTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ .getResult(0);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = tensor::CollapseShapeOp::create(
@@ -639,9 +633,8 @@ public:
auto emptyTensor =
tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
outputTy.getElementType(), filteredDims);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
@@ -910,20 +903,18 @@ public:
rewriter, loc, accTy.getShape(), accETy, dynamicDims);
Value filledEmptyTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{initialValue},
- ValueRange{poolEmptyTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
+ ValueRange{poolEmptyTensor})
.result();
Value fakeWindowDims =
tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
// Sum across the pooled region.
- Value poolingOp = rewriter
- .create<linalg::PoolingNhwcSumOp>(
- loc, ArrayRef<Type>{accTy},
- ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr)
+ Value poolingOp = linalg::PoolingNhwcSumOp::create(
+ rewriter, loc, ArrayRef<Type>{accTy},
+ ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr)
.getResult(0);
// Normalize the summed value by the number of elements grouped in each
@@ -1050,10 +1041,9 @@ public:
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
auto scaled =
- rewriter
- .create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), poolVal, multiplier, shift,
- rewriter.getStringAttr("SINGLE_ROUND"))
+ tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+ shift, rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
// If we have quantization information we need to apply output
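
The conv and bias patterns above show the trickiest variant of the migration: linalg::GenericOp::create with a region-builder lambda. The argument order is unchanged; only the builder moves to the front. A compact sketch under those assumptions, with all maps and iterator types prepared by the caller:

    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
    #include "mlir/IR/AffineMap.h"

    // Build a linalg.generic whose body just forwards its first block argument.
    static mlir::Value buildPassThroughGeneric(
        mlir::OpBuilder &rewriter, mlir::Location loc, mlir::TypeRange resultTypes,
        mlir::Value source, mlir::Value init,
        llvm::ArrayRef<mlir::AffineMap> indexingMaps,
        llvm::ArrayRef<mlir::utils::IteratorType> iteratorTypes) {
      return mlir::linalg::GenericOp::create(
                 rewriter, loc, resultTypes, mlir::ValueRange({source}),
                 mlir::ValueRange({init}), indexingMaps, iteratorTypes,
                 [](mlir::OpBuilder &b, mlir::Location nestedLoc,
                    mlir::ValueRange args) {
                   mlir::linalg::YieldOp::create(b, nestedLoc, args.front());
                 })
          .getResult(0);
    }
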
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index b83f5ec9..f8efb34 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -13,7 +13,6 @@
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 77aab85..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
// by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
if (!supportsMMaMatrixType(op, useNvGpu)) {
- LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+ LDBG() << "cannot convert op: " << *op;
return true;
}
return false;
@@ -482,14 +481,12 @@ struct CombineTransferReadOpTranspose final
permutationMap.compose(transferReadOp.getPermutationMap());
auto loc = op.getLoc();
- Value result =
- rewriter
- .create<vector::TransferReadOp>(
- loc, resultType, transferReadOp.getBase(),
- transferReadOp.getIndices(), AffineMapAttr::get(newMap),
- transferReadOp.getPadding(), transferReadOp.getMask(),
- transferReadOp.getInBoundsAttr())
- .getResult();
+ Value result = vector::TransferReadOp::create(
+ rewriter, loc, resultType, transferReadOp.getBase(),
+ transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+ transferReadOp.getPadding(), transferReadOp.getMask(),
+ transferReadOp.getInBoundsAttr())
+ .getResult();
// Fuse through the integer extend op.
if (extOp) {
@@ -550,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
@@ -585,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
- LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ LDBG() << "transfer read to: " << load;
return success();
}
@@ -599,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
auto it = valueMapping.find(op.getVector());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no mapping\n");
+ LDBG() << "no mapping";
return rewriter.notifyMatchFailure(op, "no mapping");
}
@@ -615,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
- LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+ LDBG() << "transfer write to: " << store;
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -643,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
- LLVM_DEBUG(DBGS() << "not a splat\n");
+ LDBG() << "not a splat";
return rewriter.notifyMatchFailure(op, "not a splat");
}
@@ -679,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineMap map = op.getPermutationMap();
if (map.getNumResults() != 2) {
- LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
- "is not a 2d operand\n");
+ LDBG() << "Failed because the result of `vector.transfer_read` "
+ "is not a 2d operand";
return failure();
}
@@ -693,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
- LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
- "expressions, then transpose cannot be determined.\n");
+ LDBG() << "Failed because expressions are not affine dim "
+ "expressions, then transpose cannot be determined.";
return failure();
}
@@ -711,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
FailureOr<bool> transpose = isTransposed(op);
if (failed(transpose)) {
- LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+ LDBG() << "failed to determine the transpose";
return rewriter.notifyMatchFailure(
op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
}
@@ -733,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
if (failed(params)) {
- LLVM_DEBUG(
- DBGS()
- << "failed to convert vector.transfer_read to ldmatrix. "
- << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.";
return rewriter.notifyMatchFailure(
op, "failed to convert vector.transfer_read to ldmatrix; this op "
"likely should not be converted to a nvgpu.ldmatrix call.");
@@ -747,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
- LLVM_DEBUG(DBGS() << "no offsets\n");
+ LDBG() << "no offsets";
return rewriter.notifyMatchFailure(op, "no offsets");
}
@@ -936,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1134,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
loop.getNumResults())))
rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
- LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
- LLVM_DEBUG(DBGS() << "erase: " << loop);
+ LDBG() << "newLoop now: " << newLoop;
+ LDBG() << "stripped scf.for: " << loop;
+ LDBG() << "erase: " << loop;
rewriter.eraseOp(loop);
return newLoop;
@@ -1152,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+ LDBG() << "no value mapping for: " << operand.value();
continue;
}
argMapping.push_back(std::make_pair(
@@ -1170,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
- LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ LDBG() << "scf.for to: " << newForOp;
return success();
}
@@ -1193,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
}
scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1246,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
- LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ LDBG() << "Process op: " << *op;
// Apparently callers do not want to early exit on failure here.
auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
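
VectorToGPU also drops its local DBGS()/DBGSNL() helpers in favor of LDBG() from llvm/Support/DebugLog.h. A minimal sketch of the new style, assuming LDBG() tags each line with DEBUG_TYPE and appends the newline itself (consistent with the trailing "\n" being removed in every converted call above):

    #define DEBUG_TYPE "vector-to-gpu"
    #include "mlir/IR/Operation.h"
    #include "llvm/Support/DebugLog.h"

    static void logUnconvertibleOp(mlir::Operation *op) {
      // Old: LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
      LDBG() << "cannot convert op: " << *op; // prefix and newline are implicit
    }
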
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9cd491c..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -29,7 +29,9 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Casting.h"
+
#include <optional>
using namespace mlir;
@@ -1068,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1204,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 4c1047a..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,7 +24,6 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -691,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1644,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
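
Both the LLVM and SPIR-V lowerings delete their vector.extractelement and vector.insertelement patterns here, consistent with the deprecation of those ops; producers are expected to emit vector.extract and vector.insert instead. A hedged builder-side sketch of the replacement, using a static lane index as elsewhere in this patch:

    #include "mlir/Dialect/Vector/IR/VectorOps.h"

    // Extract lane 0 of `vec` and insert `scalar` at lane 0 of `dest`;
    // vector.extract/insert subsume the removed *element ops.
    static void extractAndInsert(mlir::OpBuilder &rewriter, mlir::Location loc,
                                 mlir::Value vec, mlir::Value dest,
                                 mlir::Value scalar) {
      mlir::Value elem = mlir::vector::ExtractOp::create(rewriter, loc, vec, 0);
      mlir::Value updated =
          mlir::vector::InsertOp::create(rewriter, loc, scalar, dest, 0);
      (void)elem;
      (void)updated;
    }
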
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 2411af0..4dfcb2b 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -10,7 +10,6 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270..9a0a230 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,42 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+namespace {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
+/// information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+ PatternRewriter &rewriter) const override {
+ bool modified = false;
+ auto foldCast = [&](OpOperand &operand) {
+ if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
+ if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+ rewriter.modifyOpInPlace(gatherOp,
+ [&] { operand.assign(castOp.getSource()); });
+ modified = true;
+ }
+ }
+ };
+
+ foldCast(gatherOp.getSrcMutable());
+ foldCast(gatherOp.getDstMutable());
+
+ return success(modified);
+ }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
LogicalResult TransposeLoadOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 17bbe54..729e3da 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
- ResolveStridedMetadata.cpp
+ FoldMemRefsOps.cpp
MaskedloadToLoad.cpp
+ ResolveStridedMetadata.cpp
ADDITIONAL_HEADER_DIRS
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
@@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRAMDGPUUtils
+ MLIRAffineUtils
MLIRArithDialect
MLIRMemRefDialect
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 37e0d2a..6d1f64e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -99,8 +99,8 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
Type allBitsType = rewriter.getIntegerType(bitwidth);
auto allBitsVecType = VectorType::get({1}, allBitsType);
- Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
- Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
+ Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
+ Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
return scalar;
}
@@ -118,27 +118,27 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
SmallVector<NamedAttribute> loadAttrs;
patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
- Value initialLoad =
- rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
+ Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
+ invariantArgs, loadAttrs);
Block *currentBlock = rewriter.getInsertionBlock();
Block *afterAtomic =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
+ cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
rewriter.setInsertionPointToEnd(loopBlock);
Value prevLoad = loopBlock->getArgument(0);
- Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
+ Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
dataType = operated.getType();
SmallVector<NamedAttribute> cmpswapAttrs;
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
SmallVector<Value> cmpswapArgs = {operated, prevLoad};
cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
- Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
- loc, dataType, cmpswapArgs, cmpswapAttrs);
+ Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
+ cmpswapArgs, cmpswapAttrs);
// We care about exact bitwise equality here, so do some bitcasts.
// These will fold away during lowering to the ROCDL dialect, where
@@ -150,14 +150,15 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
- rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
+ arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
atomicResForCompare =
- rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
+ arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
}
- Value canLeave = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
- rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
- loopBlock, atomicRes);
+ Value canLeave =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ atomicResForCompare, prevLoadForCompare);
+ cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{},
+ loopBlock, atomicRes);
rewriter.eraseOp(atomicOp);
return success();
}
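
The rewritten atomic-emulation pattern keeps the same CFG as before: load once, then loop on cmpswap until the returned value bitwise-equals the previously seen one. A plain C++ analogue of that loop, as a sketch only; cmpswap stands in for the raw buffer cmpswap intrinsic and is hypothetical:

    #include <cstdint>

    // Hypothetical device-side compare-and-swap: writes `desired` if *addr
    // still holds `expected`, and returns the value observed at `addr`.
    uint32_t cmpswap(volatile uint32_t *addr, uint32_t expected,
                     uint32_t desired);

    // Emulate an atomic add the way RawBufferAtomicByCasPattern does; the
    // loop block's argument corresponds to `prev` here.
    uint32_t emulateAtomicAdd(volatile uint32_t *addr, uint32_t data) {
      uint32_t prev = *addr;             // initial RawBufferLoadOp
      for (;;) {
        uint32_t operated = prev + data; // ArithOp on the loaded value
        uint32_t seen = cmpswap(addr, prev, operated); // RawBufferAtomicCmpswapOp
        if (seen == prev)                // exact bitwise equality: we won
          return operated;
        prev = seen;                     // lost the race: retry with new value
      }
    }
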
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
new file mode 100644
index 0000000..a3fdc7e
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -0,0 +1,97 @@
+//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+
+struct AmdgpuFoldMemRefOpsPass final
+ : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateAmdgpuFoldMemRefOpsPatterns(patterns);
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ }
+};
+
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(GatherToLDSOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ Value memrefSource;
+ SmallVector<Value> sourceIndices;
+ auto foldResult =
+ llvm::TypeSwitch<Operation *, LogicalResult>(
+ op.getSrc().getDefiningOp())
+ .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+ // If the source is a SubViewOp, we can directly rewrite the
+ // GatherToLDSOp.
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, subviewOp.getMixedOffsets(),
+ subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+ op.getSrcIndices(), sourceIndices);
+ memrefSource = subviewOp.getSource();
+ return success();
+ })
+ .Case<memref::ExpandShapeOp>(
+ [&](memref::ExpandShapeOp expandShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, op.getSrcIndices(),
+ sourceIndices, false))) {
+ return failure();
+ }
+ memrefSource = expandShapeOp.getViewSource();
+ return success();
+ })
+ .Case<memref::CollapseShapeOp>(
+ [&](memref::CollapseShapeOp collapseShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+ loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+ sourceIndices))) {
+ return failure();
+ }
+ memrefSource = collapseShapeOp.getViewSource();
+ return success();
+ })
+ .Default([&](Operation *op) {
+ // If the source is not a SubViewOp, ExpandShapeOp, or
+ // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+ return rewriter.notifyMatchFailure(
+ op,
+ "source producer is not one of SubViewOp, ExpandShapeOp, or "
+ "CollapseShapeOp");
+ });
+
+ if (failed(foldResult)) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+ op.getDst(), op.getDstIndices(),
+ op.getTransferType());
+
+ return success();
+ }
+};
+
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
+} // namespace mlir::amdgpu
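A hedged usage sketch for the new pattern entry point; `ctx` and `module` are hypothetical stand-ins for an MLIRContext pointer and a loaded ModuleOp, mirroring what AmdgpuFoldMemRefOpsPass::runOnOperation does above:

  RewritePatternSet patterns(ctx);
  mlir::amdgpu::populateAmdgpuFoldMemRefOpsPatterns(patterns);
  // Walk-based driver: applies the folds in one sweep, no fixpoint iteration.
  mlir::walkAndApplyPatterns(module, std::move(patterns));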
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index af8634c..f15c63c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -54,11 +54,11 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
vector::MaskedLoadOp maskedOp,
bool passthru) {
VectorType vectorType = maskedOp.getVectorType();
- Value load = builder.create<vector::LoadOp>(
- loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+ Value load = vector::LoadOp::create(
+ builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
if (passthru)
- load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
- load, maskedOp.getPassThru());
+ load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
+ load, maskedOp.getPassThru());
return load;
}
@@ -108,7 +108,7 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
SmallVector<OpFoldResult> indices = maskedOp.getIndices();
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
SmallVector<OpFoldResult> strides =
stridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
@@ -122,47 +122,47 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
// delta = bufferSize - linearizedOffset
Value vectorSizeOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
Value linearIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
Value totalSize = getValueOrCreateConstantIndexOp(
rewriter, loc, linearizedInfo.linearizedSize);
- Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+ Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
// 1) check if delta < vectorSize
- Value isOutofBounds = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+ Value isOutofBounds = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
    // 2) check if (delta % elements_per_word != 0)
- Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::divideCeil(32, elementBitWidth));
- Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne,
- rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ Value elementsPerWord = arith::ConstantIndexOp::create(
+ rewriter, loc, llvm::divideCeil(32, elementBitWidth));
+ Value isNotWordAligned = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne,
+ arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
    // We take the fallback of the default maskedload lowering only if the
    // access is both out-of-bounds and not word aligned. The fallback ensures
    // correct results when loading at the buffer boundary, since buffer loads
    // return inconsistent zeros for the whole word when the boundary is crossed.
Value ifCondition =
- rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+ arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*maskedOp.getOperation());
read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
Value readResult = read->getResult(0);
- builder.create<scf::YieldOp>(loc, readResult);
+ scf::YieldOp::create(builder, loc, readResult);
};
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
/*passthru=*/true);
- rewriter.create<scf::YieldOp>(loc, res);
+ scf::YieldOp::create(rewriter, loc, res);
};
auto ifOp =
- rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+ scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
rewriter.replaceOp(maskedOp, ifOp);
@@ -185,13 +185,13 @@ struct FullMaskedLoadToConditionalLoad
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
/*passthru=*/false);
- rewriter.create<scf::YieldOp>(loc, res);
+ scf::YieldOp::create(rewriter, loc, res);
};
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
- rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+ scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
};
- auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
- falseBuilder);
+ auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
+ falseBuilder);
rewriter.replaceOp(loadOp, ifOp);
return success();
}
@@ -210,11 +210,12 @@ struct FullMaskedStoreToConditionalStore
Value cond = maybeCond.value();
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
- rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
- storeOp.getBase(), storeOp.getIndices());
- rewriter.create<scf::YieldOp>(loc);
+ vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
+ storeOp.getBase(), storeOp.getIndices());
+ scf::YieldOp::create(rewriter, loc);
};
- auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+ auto ifOp =
+ scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
rewriter.replaceOp(storeOp, ifOp);
return success();
}
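As a reading aid for the guard built in the MaskedLoadLowering hunk above, a scalar model of the same arithmetic (assuming 32-bit buffer words, per the `divideCeil(32, elementBitWidth)` in the hunk):

  // Mirrors: ifCondition = isOutofBounds && isNotWordAligned.
  bool needsFallback(int64_t delta, int64_t vectorSize, int64_t elementBitWidth) {
    int64_t elementsPerWord = llvm::divideCeil(32, elementBitWidth);
    bool isOutOfBounds = delta < vectorSize;                 // arith.cmpi ult
    bool isNotWordAligned = (delta % elementsPerWord) != 0;  // arith.remui != 0
    return isOutOfBounds && isNotWordAligned;                // arith.andi
  }

Only when both conditions hold does the lowering keep the original masked load (tagged with kMaskedloadNeedsMask); otherwise the plain vector.load path is taken.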
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
index 195f59d..f8bab82 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
@@ -37,8 +37,8 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
return rewriter.notifyMatchFailure(metadataOp,
"not a fat raw buffer cast");
Location loc = castOp.getLoc();
- auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, castOp.getSource());
+ auto sourceMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, castOp.getSource());
SmallVector<Value> results;
if (metadataOp.getBaseBuffer().use_empty()) {
results.push_back(nullptr);
@@ -48,13 +48,13 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
if (baseBufferType == castOp.getResult().getType()) {
results.push_back(castOp.getResult());
} else {
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, castOp.getResult(), /*offset=*/0,
+ results.push_back(memref::ReinterpretCastOp::create(
+ rewriter, loc, baseBufferType, castOp.getResult(), /*offset=*/0,
/*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
}
}
if (castOp.getResetOffset())
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0));
else
results.push_back(sourceMetadata.getOffset());
llvm::append_range(results, sourceMetadata.getSizes());
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 12b375b..6f3110c 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -76,8 +76,8 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return SmallVector<Value>{
- rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
- rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+ LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
+ LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
@@ -95,15 +95,14 @@ static Value getStride(Location loc, MemRefType mType, Value base,
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
- Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
- return rewriter
- .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
- memrefDescriptor.stride(rewriter, loc, preLast))
+ Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+ return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
+ memrefDescriptor.stride(rewriter, loc, preLast))
.getResult();
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
- return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+ return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
.getResult();
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8d7053c..22608a1 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -26,7 +26,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
@@ -40,7 +40,6 @@ using llvm::divideFloorSigned;
using llvm::mod;
#define DEBUG_TYPE "affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
@@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
AffineMap *map,
ValueRange dims,
ValueRange syms) {
+ LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
AffineMap affineMinMap = minOp.getAffineMap();
- LLVM_DEBUG({
- DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
- });
-
// Check the value is positive.
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
// Compare each expression in the minimum against 0.
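The same logging migration recurs across this patch: the hand-rolled `DBGS()` macro over llvm/Support/Debug.h gives way to `LDBG()` from llvm/Support/DebugLog.h, which takes over the debug-type prefixing and appends the newline itself, which is why the converted call sites drop both `DBGS()` and the trailing `"\n"`. A minimal before/after sketch (`op` is a hypothetical `Operation *`):

  #define DEBUG_TYPE "affine-ops"
  #include "llvm/Support/DebugLog.h"

  // Before: LLVM_DEBUG(DBGS() << "visiting: " << *op << "\n");
  // After: prefix and newline are supplied by the macro.
  LDBG() << "visiting: " << *op;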
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index f18cec5..df39544 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -202,7 +202,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
void AffineDataCopyGeneration::runOnOperation() {
func::FuncOp f = getOperation();
OpBuilder topBuilder(f.getBody());
- zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+ zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0);
// Nests that are copy-in's or copy-out's; the root AffineForOps of those
// nests are stored herein.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 5430bdc..c0d174a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -58,8 +58,9 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
// Note: basis elements and their products are, definitionally,
// non-negative, so `nuw` is justified.
if (dynamicPart)
- dynamicPart = rewriter.create<arith::MulIOp>(
- loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
+ dynamicPart =
+ arith::MulIOp::create(rewriter, loc, dynamicPart,
+ dynamicBasis[dynamicIndex - 1], ovflags);
else
dynamicPart = dynamicBasis[dynamicIndex - 1];
--dynamicIndex;
@@ -74,7 +75,7 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
if (dynamicPart)
stride =
- rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
+ arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
result.push_back(stride);
}
}
@@ -106,20 +107,20 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
Value initialPart =
- rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+ arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
results.push_back(initialPart);
auto emitModTerm = [&](Value stride) -> Value {
- Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
- Value remainderNegative = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, remainder, zero);
+ Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
+ Value remainderNegative = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
// If the correction is relevant, this term is <= stride, which is known
// to be positive in `index`. Otherwise, while 2 * stride might overflow,
// this branch won't be taken, so the risk of `poison` is fine.
- Value corrected = rewriter.create<arith::AddIOp>(
- loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
- Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
- corrected, remainder);
+ Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
+ arith::IntegerOverflowFlags::nsw);
+ Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
+ corrected, remainder);
return mod;
};
@@ -131,7 +132,7 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
// We know both inputs are positive, so floorDiv == div.
// This could potentially be a divui, but it's not clear if that would
// cause issues.
- Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+ Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
results.push_back(divided);
}
@@ -167,8 +168,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx = rewriter.create<arith::MulIOp>(
- loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+ Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
+ arith::IntegerOverflowFlags::nsw);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
@@ -184,8 +185,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
Value result = scaledValues.front().first;
for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
- arith::IntegerOverflowFlags::nsw);
+ result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
}
rewriter.replaceOp(op, result);
return success();
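A scalar model of the mod term built in emitModTerm above, for checking the negative-remainder correction (the stride is positive by construction, as the surrounding comments note):

  int64_t nonNegativeMod(int64_t linearIdx, int64_t stride) {
    int64_t remainder = linearIdx % stride;   // arith.remsi
    bool negative = remainder < 0;            // arith.cmpi slt
    int64_t corrected = remainder + stride;   // arith.addi, nsw justified above
    return negative ? corrected : remainder;  // arith.select
  }

For example, nonNegativeMod(-1, 4) yields 3 where plain remsi would yield -1.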
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 4fd0cf9..6265f46 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -15,13 +15,13 @@
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
using namespace mlir::affine;
#define DEBUG_TYPE "decompose-affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
/// Count the number of loops surrounding `operand` such that operand could be
/// hoisted above.
@@ -88,8 +88,8 @@ static AffineApplyOp createSubApply(RewriterBase &rewriter,
auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
SmallVector<Value> rhsOperands = originalOp->getOperands();
canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
- return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
- rhsOperands);
+ return AffineApplyOp::create(rewriter, originalOp.getLoc(), rhsMap,
+ rhsOperands);
}
FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
@@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(
op, "only add or mul binary expr can be reassociated");
- LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
+ LDBG() << "Start decomposeIntoFinerGrainedOps: " << op;
// 2. Iteratively extract the RHS subexpressions while the top-level binary
// expr kind remains the same.
@@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
subExpressions.push_back(remainingExp);
- LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
+ LDBG() << "--terminal: " << subExpressions.back();
break;
}
subExpressions.push_back(currentBinExpr.getRHS());
- LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
+ LDBG() << "--subExpr: " << subExpressions.back();
remainingExp = currentBinExpr.getLHS();
}
@@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
return getMaxSymbol(e1) < getMaxSymbol(e2);
});
- LLVM_DEBUG(
- llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
- llvm::dbgs() << "\n");
+ LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions);
// 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
auto s0 = getAffineSymbolExpr(0, ctx);
@@ -160,9 +158,9 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
auto current = createSubApply(rewriter, op, subExpressions[0]);
for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
Value tmp = createSubApply(rewriter, op, subExpressions[i]);
- current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
- ValueRange{current, tmp});
- LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
+ current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
+ ValueRange{current, tmp});
+ LDBG() << "--reassociate into: " << current;
}
// 5. Replace original op.
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 1d5a665..6c9adff 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -424,7 +424,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the block, because loop nests can be reordered
// during the fusion pass.
- Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
+ Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
// Build an AffineMap to remap access functions based on lower bound offsets.
SmallVector<AffineExpr, 4> remapExprs;
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 05a352f..c942c02 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -100,16 +100,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
}
// Create and place the alloc right before the 'affine.for' operation.
- Value newMemRef = bOuter.create<memref::AllocOp>(
- forOp.getLoc(), newMemRefType, allocOperands);
+ Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(),
+ newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
int64_t step = forOp.getStepAsInt();
auto modTwoMap =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
- auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
- forOp.getInductionVar());
+ auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap,
+ forOp.getInductionVar());
// replaceAllMemRefUsesWith will succeed unless the forOp body has
// non-dereferencing uses of the memref (dealloc's are fine though).
@@ -130,7 +130,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
}
// Insert the dealloc op right after the for loop.
bOuter.setInsertionPointAfter(forOp);
- bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
+ memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef);
return true;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 1a266b7..9537d3e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -51,10 +51,10 @@ OpFoldResult affine::materializeComputedBound(
"expected dynamic dim");
if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
- operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
+ operands.push_back(tensor::DimOp::create(b, loc, value, *dim));
} else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
- operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
+ operands.push_back(memref::DimOp::create(b, loc, value, *dim));
} else {
llvm_unreachable("cannot generate DimOp for unsupported shaped type");
}
@@ -76,7 +76,7 @@ OpFoldResult affine::materializeComputedBound(
operands[expr.getPosition() + boundMap.getNumDims()]);
// General case: build affine.apply op.
return static_cast<OpFoldResult>(
- b.create<affine::AffineApplyOp>(loc, boundMap, operands).getResult());
+ affine::AffineApplyOp::create(b, loc, boundMap, operands).getResult());
}
FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index 8493b60..2521512 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -19,11 +19,10 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "affine-min-max"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::affine;
@@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
- LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+ LDBG() << "analyzing value: `" << affineOp;
// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
@@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
- LLVM_DEBUG({
- DBGS() << "- constructed variables are: "
- << llvm::interleaved_array(llvm::map_range(
- variables, [](const Variable &v) { return v.getMap(); }))
- << "`\n";
- });
+ LDBG() << "- constructed variables are: "
+ << llvm::interleaved_array(llvm::map_range(
+ variables, [](const Variable &v) { return v.getMap(); }));
// Get the comparison operation.
ComparisonOperator cmpOp =
@@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Initialize the bound.
Variable *bound = &v;
- LLVM_DEBUG({
- DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
- << "`\n";
- });
+ LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
@@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
- LLVM_DEBUG({
- DBGS() << "- comparing with variable: #" << jEqClass
- << ", with map: " << nv->getMap() << "\n";
- });
+ LDBG() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap();
// Compare the variables.
FailureOr<bool> cmpResult =
@@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// The variables cannot be compared.
if (failed(cmpResult)) {
- LLVM_DEBUG({
- DBGS() << "-- classes: #" << i << ", #" << jEqClass
- << " cannot be merged\n";
- });
+ LDBG() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged";
continue;
}
// Join the equivalent classes and update the bound if necessary.
- LLVM_DEBUG({
- DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
- << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
- });
+ LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", is cmp(lhs, rhs): " << *cmpResult << "`";
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
@@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
- LLVM_DEBUG(
- { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ LDBG() << "- the affine operation couldn't get simplified";
return false;
}
@@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
- LLVM_DEBUG({
- DBGS() << "- starting from map: " << affineMap << "\n";
- DBGS() << "- creating new map with: \n";
- DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
- DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
- DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
- });
+ LDBG() << "- starting from map: " << affineMap;
+ LDBG() << "- creating new map with:";
+ LDBG() << "--- dims: " << affineMap.getNumDims();
+ LDBG() << "--- syms: " << affineMap.getNumSymbols();
+ LDBG() << "--- res: " << llvm::interleaved_array(results);
affineMap =
AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
@@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
- LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ LDBG() << "- simplified affine op: `" << affineOp << "`";
return true;
}
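The simplification above leans on llvm::IntEqClasses to merge min/max results once one is proven to bound another; a minimal sketch of that bookkeeping, with `numResults`, `i`, and `j` as hypothetical result positions:

  llvm::IntEqClasses boundedClasses(numResults);
  // A successful cmp(lhs, rhs) collapses results i and j into one class; the
  // bounds map then keeps whichever variable is the tighter bound.
  boundedClasses.join(i, j);
  unsigned leader = boundedClasses.findLeader(i);  // shared representative

When fewer bounds remain than affineMap.getNumResults(), the map is rebuilt from the class representatives; otherwise the op is left untouched.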
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 7fae260..50a0f3d 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -905,8 +905,8 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map,
for (auto resultExpr : map.getResults()) {
auto singleResMap =
AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
- auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- mapOperands);
+ auto afOp = AffineApplyOp::create(state.builder, op->getLoc(), singleResMap,
+ mapOperands);
results.push_back(afOp);
}
}
@@ -961,7 +961,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
auto vecForOp = cast<AffineForOp>(parentOp);
state.builder.setInsertionPointToStart(vecForOp.getBody());
auto newConstOp =
- state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr);
+ arith::ConstantOp::create(state.builder, constOp.getLoc(), vecAttr);
// Register vector replacement for future uses in the scope.
state.registerOpVectorReplacement(constOp, newConstOp);
@@ -986,8 +986,8 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
}
}
- auto newApplyOp = state.builder.create<AffineApplyOp>(
- applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
+ auto newApplyOp = AffineApplyOp::create(
+ state.builder, applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
// Register the new affine.apply result.
state.registerValueScalarReplacement(applyOp.getResult(),
@@ -1010,7 +1010,7 @@ static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
auto vecTy = getVectorType(scalarTy, state.strategy);
auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr);
auto newConstOp =
- state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr);
+ arith::ConstantOp::create(state.builder, oldOperand.getLoc(), vecAttr);
return newConstOp;
}
@@ -1062,11 +1062,11 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
AffineMap ubMap = vecForOp.getUpperBoundMap();
Value ub;
if (ubMap.getNumResults() == 1)
- ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(),
- vecForOp.getUpperBoundOperands());
+ ub = AffineApplyOp::create(state.builder, loc, vecForOp.getUpperBoundMap(),
+ vecForOp.getUpperBoundOperands());
else
- ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(),
- vecForOp.getUpperBoundOperands());
+ ub = AffineMinOp::create(state.builder, loc, vecForOp.getUpperBoundMap(),
+ vecForOp.getUpperBoundOperands());
// Then we compute the number of (original) iterations left in the loop.
AffineExpr subExpr =
state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1);
@@ -1080,7 +1080,7 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
Type maskTy = VectorType::get(state.strategy->vectorSizes,
state.builder.getIntegerType(1));
Value mask =
- state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft);
+ vector::CreateMaskOp::create(state.builder, loc, maskTy, itersLeft);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n"
<< itersLeft << "\n"
@@ -1123,8 +1123,8 @@ static Operation *vectorizeUniform(Value uniformVal,
state.builder.setInsertionPointAfterValue(uniformScalarRepl);
auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
- auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(),
- vectorTy, uniformScalarRepl);
+ auto bcastOp = BroadcastOp::create(state.builder, uniformVal.getLoc(),
+ vectorTy, uniformScalarRepl);
state.registerValueVectorReplacement(uniformVal, bcastOp);
return bcastOp;
}
@@ -1256,8 +1256,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- auto transfer = state.builder.create<vector::TransferReadOp>(
- loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
+ auto transfer = vector::TransferReadOp::create(
+ state.builder, loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
/*padding=*/std::nullopt, permutationMap);
// Register replacement for future uses in the scope.
@@ -1303,9 +1303,9 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- auto transfer = state.builder.create<vector::TransferWriteOp>(
- storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
- permutationMap);
+ auto transfer = vector::TransferWriteOp::create(
+ state.builder, storeOp.getLoc(), vectorValue, storeOp.getMemRef(),
+ indices, permutationMap);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
// Register replacement for future uses in the scope.
@@ -1322,7 +1322,7 @@ static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
return false;
Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
state.builder, value.getLoc());
- if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
+ if (auto constOp = value.getDefiningOp<arith::ConstantOp>())
return constOp.getValue() == valueAttr;
return false;
}
@@ -1387,10 +1387,10 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp,
}
}
- auto vecForOp = state.builder.create<AffineForOp>(
- forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
- forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
- vecIterOperands,
+ auto vecForOp = AffineForOp::create(
+ state.builder, forOp.getLoc(), forOp.getLowerBoundOperands(),
+ forOp.getLowerBoundMap(), forOp.getUpperBoundOperands(),
+ forOp.getUpperBoundMap(), newStep, vecIterOperands,
/*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
// Make sure we don't create a default terminator in the loop body as
// the proper terminator will be added during vectorization.
@@ -1512,8 +1512,8 @@ static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
// IterOperands are neutral element vectors.
Value neutralVal = cast<AffineForOp>(newParentOp).getInits()[i];
state.builder.setInsertionPoint(combinerOps.back());
- Value maskedReducedVal = state.builder.create<arith::SelectOp>(
- reducedVal.getLoc(), mask, reducedVal, neutralVal);
+ Value maskedReducedVal = arith::SelectOp::create(
+ state.builder, reducedVal.getLoc(), mask, reducedVal, neutralVal);
LLVM_DEBUG(
dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"
"produces value for a yield Op: "
@@ -1865,7 +1865,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
return success();
}
-
/// External utility to vectorize affine loops in 'loops' using the n-D
/// vectorization factors in 'vectorSizes'. By default, each vectorization
/// factor is applied inner-to-outer to the loops of each loop nest.
@@ -1927,4 +1926,4 @@ LogicalResult mlir::affine::vectorizeAffineLoopNest(
if (failed(verifyLoopNesting(loops)))
return failure();
return vectorizeLoopNest(loops, strategy);
-}
+}
\ No newline at end of file
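A scalar model of the loop mask built in createMask above, assuming a unit original step and 1-D vectorization for brevity:

  // Lane i of the mask is set while i < itersLeft, per vector.create_mask.
  bool laneIsActive(int64_t iv, int64_t ub, int64_t lane) {
    int64_t itersLeft = ub - iv;  // affine.apply of d0 - d1
    return lane < itersLeft;
  }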
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 21f69ad..2de057d 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -54,8 +54,8 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
OpBuilder b(forOp);
auto lbMap = forOp.getLowerBoundMap();
- auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
- forOp.getLowerBoundOperands());
+ auto lb = AffineApplyOp::create(b, forOp.getLoc(), lbMap,
+ forOp.getLowerBoundOperands());
// For each upper bound expr, get the range.
// Eg: affine.for %i = lb to min (ub1, ub2),
@@ -71,7 +71,7 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
tripCountMap.getNumSymbols(), bumpExprs[i]);
bumpValues[i] =
- b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
+ AffineApplyOp::create(b, forOp.getLoc(), bumpMap, tripCountOperands);
}
SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
@@ -134,8 +134,8 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
builder.setInsertionPointToStart(&func.getFunctionBody().front());
else
builder.setInsertionPoint(forOp);
- auto constOp = builder.create<arith::ConstantIndexOp>(
- forOp.getLoc(), forOp.getConstantLowerBound());
+ auto constOp = arith::ConstantIndexOp::create(
+ builder, forOp.getLoc(), forOp.getConstantLowerBound());
iv.replaceAllUsesWith(constOp);
} else {
auto lbOperands = forOp.getLowerBoundOperands();
@@ -146,7 +146,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
iv.replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp =
- builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+ AffineApplyOp::create(builder, forOp.getLoc(), lbMap, lbOperands);
iv.replaceAllUsesWith(affineApplyOp);
}
}
@@ -181,8 +181,8 @@ static AffineForOp generateShiftedLoop(
assert(ubMap.getNumInputs() == ubOperands.size());
auto loopChunk =
- b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap, ubOperands,
- ubMap, srcForOp.getStepAsInt());
+ AffineForOp::create(b, srcForOp.getLoc(), lbOperands, lbMap, ubOperands,
+ ubMap, srcForOp.getStepAsInt());
auto loopChunkIV = loopChunk.getInductionVar();
auto srcIV = srcForOp.getInductionVar();
@@ -197,8 +197,8 @@ static AffineForOp generateShiftedLoop(
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcIV.use_empty() && shift != 0) {
- auto ivRemap = bodyBuilder.create<AffineApplyOp>(
- srcForOp.getLoc(),
+ auto ivRemap = AffineApplyOp::create(
+ bodyBuilder, srcForOp.getLoc(),
bodyBuilder.getSingleDimShiftAffineMap(
-static_cast<int64_t>(srcForOp.getStepAsInt() * shift)),
loopChunkIV);
@@ -433,7 +433,7 @@ static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
for (unsigned i = 0; i < width; i++) {
OpBuilder b(topLoop);
// Loop bounds will be set later.
- AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
+ AffineForOp pointLoop = AffineForOp::create(b, loc, 0, 0);
pointLoop.getBody()->getOperations().splice(
pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
topLoop);
@@ -447,7 +447,7 @@ static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
for (unsigned i = width; i < 2 * width; i++) {
OpBuilder b(topLoop);
// Loop bounds will be set later.
- AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
+ AffineForOp tileSpaceLoop = AffineForOp::create(b, loc, 0, 0);
tileSpaceLoop.getBody()->getOperations().splice(
tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
topLoop);
@@ -1048,7 +1048,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
// iv' = iv + i * step
auto d0 = b.getAffineDimExpr(0);
auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
- return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
+ return AffineApplyOp::create(b, forOp.getLoc(), bumpMap, iv);
},
/*annotateFn=*/annotateFn,
/*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
@@ -1212,7 +1212,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
auto ivUnroll =
- builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
+ AffineApplyOp::create(builder, forOp.getLoc(), bumpMap, forOpIV);
operandMaps[i - 1].map(forOpIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
@@ -1541,8 +1541,8 @@ stripmineSink(AffineForOp forOp, uint64_t factor,
for (auto t : targets) {
// Insert newForOp before the terminator of `t`.
auto b = OpBuilder::atBlockTerminator(t.getBody());
- auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
- ubOperands, ubMap, originalStep);
+ auto newForOp = AffineForOp::create(b, t.getLoc(), lbOperands, lbMap,
+ ubOperands, ubMap, originalStep);
auto begin = t.getBody()->begin();
// Skip terminator and `newForOp` which is just before the terminator.
auto nOps = t.getBody()->getOperations().size() - 2;
@@ -1616,9 +1616,9 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
// 1. Store the upper bound of the outermost loop in a variable.
Value prev;
if (!llvm::hasSingleElement(origUbMap.getResults()))
- prev = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
+ prev = AffineMinOp::create(builder, loc, origUbMap, ubOperands);
else
- prev = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
+ prev = AffineApplyOp::create(builder, loc, origUbMap, ubOperands);
upperBoundSymbols.push_back(prev);
// 2. Emit code computing the upper bound of the coalesced loop as product of
@@ -1630,16 +1630,16 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
Value upperBound;
// If upper bound map has more than one result, take their minimum.
if (!llvm::hasSingleElement(origUbMap.getResults()))
- upperBound = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
+ upperBound = AffineMinOp::create(builder, loc, origUbMap, ubOperands);
else
- upperBound = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
+ upperBound = AffineApplyOp::create(builder, loc, origUbMap, ubOperands);
upperBoundSymbols.push_back(upperBound);
SmallVector<Value, 4> operands;
operands.push_back(prev);
operands.push_back(upperBound);
// Maintain running product of loop upper bounds.
- prev = builder.create<AffineApplyOp>(
- loc,
+ prev = AffineApplyOp::create(
+ builder, loc,
AffineMap::get(/*dimCount=*/1,
/*symbolCount=*/1,
builder.getAffineDimExpr(0) *
@@ -1668,13 +1668,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
SmallVector<Value, 4> operands;
operands.push_back(previous);
operands.push_back(upperBoundSymbols[idx]);
- previous = builder.create<AffineApplyOp>(
- loc,
- AffineMap::get(
- /*dimCount=*/1, /*symbolCount=*/1,
- builder.getAffineDimExpr(0).floorDiv(
- builder.getAffineSymbolExpr(0))),
- operands);
+ previous = AffineApplyOp::create(builder, loc,
+ AffineMap::get(
+ /*dimCount=*/1, /*symbolCount=*/1,
+ builder.getAffineDimExpr(0).floorDiv(
+ builder.getAffineSymbolExpr(0))),
+ operands);
}
// Modified value of the induction variables of the nested loops after
// coalescing.
@@ -1685,8 +1684,8 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
SmallVector<Value, 4> applyOperands;
applyOperands.push_back(previous);
applyOperands.push_back(upperBoundSymbols[idx - 1]);
- inductionVariable = builder.create<AffineApplyOp>(
- loc,
+ inductionVariable = AffineApplyOp::create(
+ builder, loc,
AffineMap::get(
/*dimCount=*/1, /*symbolCount=*/1,
builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)),
@@ -1723,21 +1722,21 @@ void mlir::affine::mapLoopToProcessorIds(scf::ForOp forOp,
Value linearIndex = processorId.front();
for (unsigned i = 1, e = processorId.size(); i < e; ++i) {
- auto mulApplyOp = b.create<AffineApplyOp>(
- loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
- linearIndex = b.create<AffineApplyOp>(
- loc, addMap, ValueRange{mulApplyOp, processorId[i]});
+ auto mulApplyOp = AffineApplyOp::create(
+ b, loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
+ linearIndex = AffineApplyOp::create(b, loc, addMap,
+ ValueRange{mulApplyOp, processorId[i]});
}
- auto mulApplyOp = b.create<AffineApplyOp>(
- loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
- Value lb = b.create<AffineApplyOp>(
- loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
+ auto mulApplyOp = AffineApplyOp::create(
+ b, loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
+ Value lb = AffineApplyOp::create(
+ b, loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
forOp.setLowerBound(lb);
Value step = forOp.getStep();
for (auto numProcs : numProcessors)
- step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step});
+ step = AffineApplyOp::create(b, loc, mulMap, ValueRange{numProcs, step});
forOp.setStep(step);
}
@@ -1874,7 +1873,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
auto fastBufOffsetMap =
AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
- auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);
+ auto offset = AffineApplyOp::create(b, loc, fastBufOffsetMap, lbOperands);
// Construct the subscript for the fast memref being copied into/from:
// x - offset_x.
@@ -1901,16 +1900,16 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
if (!isCopyOut) {
// Copy in.
- auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
- b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap,
- fastBufMapOperands);
+ auto load = AffineLoadOp::create(b, loc, memref, memIndices);
+ AffineStoreOp::create(b, loc, load, fastMemRef, fastBufMap,
+ fastBufMapOperands);
return copyNestRoot;
}
// Copy out.
auto load =
- b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands);
- b.create<AffineStoreOp>(loc, load, memref, memIndices);
+ AffineLoadOp::create(b, loc, fastMemRef, fastBufMap, fastBufMapOperands);
+ AffineStoreOp::create(b, loc, load, memref, memIndices);
return copyNestRoot;
}
@@ -1945,7 +1944,7 @@ static LogicalResult generateCopy(
auto f = begin->getParentOfType<FunctionOpInterface>();
OpBuilder topBuilder(f.getFunctionBody());
- Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+ Value zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0);
*sizeInBytes = 0;
@@ -2056,7 +2055,7 @@ static LogicalResult generateCopy(
memIndices.push_back(zeroIndex);
} else {
memIndices.push_back(
- top.create<arith::ConstantIndexOp>(loc, indexVal).getResult());
+ arith::ConstantIndexOp::create(top, loc, indexVal).getResult());
}
} else {
// The coordinate for the start location is just the lower bound along the
@@ -2070,7 +2069,8 @@ static LogicalResult generateCopy(
lbs[d] = lbs[d].replaceDimsAndSymbols(
/*dimReplacements=*/{}, symReplacements, lbs[d].getNumSymbols(),
/*numResultSyms=*/0);
- memIndices.push_back(b.create<AffineApplyOp>(loc, lbs[d], regionSymbols));
+ memIndices.push_back(
+ AffineApplyOp::create(b, loc, lbs[d], regionSymbols));
}
// The fast buffer is copied into at location zero; addressing is relative.
bufIndices.push_back(zeroIndex);
@@ -2094,7 +2094,7 @@ static LogicalResult generateCopy(
// Create the fast memory space buffer just before the 'affine.for'
// operation.
fastMemRef =
- prologue.create<memref::AllocOp>(loc, fastMemRefType).getResult();
+ memref::AllocOp::create(prologue, loc, fastMemRefType).getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
// fastMemRefType is a constant shaped memref.
@@ -2111,7 +2111,7 @@ static LogicalResult generateCopy(
fastMemRef = fastBufferMap[memref];
}
- auto numElementsSSA = top.create<arith::ConstantIndexOp>(loc, *numElements);
+ auto numElementsSSA = arith::ConstantIndexOp::create(top, loc, *numElements);
Value dmaStride;
Value numEltPerDmaStride;
@@ -2128,9 +2128,9 @@ static LogicalResult generateCopy(
if (!dmaStrideInfos.empty()) {
dmaStride =
- top.create<arith::ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
- numEltPerDmaStride = top.create<arith::ConstantIndexOp>(
- loc, dmaStrideInfos[0].numEltPerStride);
+ arith::ConstantIndexOp::create(top, loc, dmaStrideInfos[0].stride);
+ numEltPerDmaStride = arith::ConstantIndexOp::create(
+ top, loc, dmaStrideInfos[0].numEltPerStride);
}
}
@@ -2160,21 +2160,21 @@ static LogicalResult generateCopy(
// Create a tag (single element 1-d memref) for the DMA.
auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
copyOptions.tagMemorySpace);
- auto tagMemRef = prologue.create<memref::AllocOp>(loc, tagMemRefType);
+ auto tagMemRef = memref::AllocOp::create(prologue, loc, tagMemRefType);
SmallVector<Value, 4> tagIndices({zeroIndex});
auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
if (!region.isWrite()) {
// DMA non-blocking read from original buffer to fast buffer.
- b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
- fastMemRef, bufAffineMap, bufIndices,
- tagMemRef, tagAffineMap, tagIndices,
- numElementsSSA, dmaStride, numEltPerDmaStride);
+ AffineDmaStartOp::create(b, loc, memref, memAffineMap, memIndices,
+ fastMemRef, bufAffineMap, bufIndices, tagMemRef,
+ tagAffineMap, tagIndices, numElementsSSA,
+ dmaStride, numEltPerDmaStride);
} else {
// DMA non-blocking write from fast buffer to the original memref.
- auto op = b.create<AffineDmaStartOp>(
- loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
+ auto op = AffineDmaStartOp::create(
+ b, loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
dmaStride, numEltPerDmaStride);
// Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
@@ -2184,11 +2184,11 @@ static LogicalResult generateCopy(
}
// Matching DMA wait to block on completion; tag always has a 0 index.
- b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
- numElementsSSA);
+ AffineDmaWaitOp::create(b, loc, tagMemRef, tagAffineMap, zeroIndex,
+ numElementsSSA);
// Generate dealloc for the tag.
- auto tagDeallocOp = epilogue.create<memref::DeallocOp>(loc, tagMemRef);
+ auto tagDeallocOp = memref::DeallocOp::create(epilogue, loc, tagMemRef);
if (*nEnd == end && isCopyOutAtEndOfBlock)
// Since new ops are being appended (for outgoing DMAs), adjust the end to
// mark end of range of the original.
@@ -2197,7 +2197,7 @@ static LogicalResult generateCopy(
// Generate dealloc for the buffer.
if (!existingBuf) {
- auto bufDeallocOp = epilogue.create<memref::DeallocOp>(loc, fastMemRef);
+ auto bufDeallocOp = memref::DeallocOp::create(epilogue, loc, fastMemRef);
// When generating pointwise copies, `nEnd' has to be set to deallocOp on
// the fast buffer (since it marks the new end insertion point).
if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
@@ -2567,8 +2567,8 @@ AffineForOp mlir::affine::createCanonicalizedAffineForOp(
canonicalizeMapAndOperands(&ubMap, &upperOperands);
ubMap = removeDuplicateExprs(ubMap);
- return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
- step);
+ return AffineForOp::create(b, loc, lowerOperands, lbMap, upperOperands, ubMap,
+ step);
}
/// Creates an AffineIfOp that encodes the conditional to choose between
@@ -2651,8 +2651,8 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
SmallVector<Value, 4> setOperands;
cst.getValues(0, cst.getNumDimAndSymbolVars(), &setOperands);
canonicalizeSetAndOperands(&ifCondSet, &setOperands);
- return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
- /*withElseRegion=*/true);
+ return AffineIfOp::create(b, loops[0].getLoc(), ifCondSet, setOperands,
+ /*withElseRegion=*/true);
}
/// Create the full tile loop nest (along with its body).
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7bb158e..845be20 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -56,7 +56,7 @@ public:
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return nullptr;
- auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
+ auto op = OpTy::create(builder, loc, lhs, rhs, overflowFlags);
return op.getResult();
}
@@ -90,14 +90,14 @@ public:
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
- Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value isRemainderNegative = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, remainder, zeroCst);
+ Value remainder = arith::RemSIOp::create(builder, loc, lhs, rhs);
+ Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value isRemainderNegative = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::slt, remainder, zeroCst);
Value correctedRemainder =
- builder.create<arith::AddIOp>(loc, remainder, rhs);
- Value result = builder.create<arith::SelectOp>(
- loc, isRemainderNegative, correctedRemainder, remainder);
+ arith::AddIOp::create(builder, loc, remainder, rhs);
+ Value result = arith::SelectOp::create(builder, loc, isRemainderNegative,
+ correctedRemainder, remainder);
return result;
}
@@ -129,18 +129,19 @@ public:
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
- Value negative = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, lhs, zeroCst);
- Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
- Value dividend =
- builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
- Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
+ Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value noneCst = arith::ConstantIndexOp::create(builder, loc, -1);
+ Value negative = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::slt, lhs, zeroCst);
+ Value negatedDecremented =
+ arith::SubIOp::create(builder, loc, noneCst, lhs);
+ Value dividend = arith::SelectOp::create(builder, loc, negative,
+ negatedDecremented, lhs);
+ Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs);
Value correctedQuotient =
- builder.create<arith::SubIOp>(loc, noneCst, quotient);
- Value result = builder.create<arith::SelectOp>(loc, negative,
- correctedQuotient, quotient);
+ arith::SubIOp::create(builder, loc, noneCst, quotient);
+ Value result = arith::SelectOp::create(builder, loc, negative,
+ correctedQuotient, quotient);
return result;
}
@@ -168,26 +169,26 @@ public:
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
- Value nonPositive = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, lhs, zeroCst);
- Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
- Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
- Value dividend =
- builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
- Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
+ Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value oneCst = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value nonPositive = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sle, lhs, zeroCst);
+ Value negated = arith::SubIOp::create(builder, loc, zeroCst, lhs);
+ Value decremented = arith::SubIOp::create(builder, loc, lhs, oneCst);
+ Value dividend = arith::SelectOp::create(builder, loc, nonPositive, negated,
+ decremented);
+ Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs);
Value negatedQuotient =
- builder.create<arith::SubIOp>(loc, zeroCst, quotient);
+ arith::SubIOp::create(builder, loc, zeroCst, quotient);
Value incrementedQuotient =
- builder.create<arith::AddIOp>(loc, quotient, oneCst);
- Value result = builder.create<arith::SelectOp>(
- loc, nonPositive, negatedQuotient, incrementedQuotient);
+ arith::AddIOp::create(builder, loc, quotient, oneCst);
+ Value result = arith::SelectOp::create(
+ builder, loc, nonPositive, negatedQuotient, incrementedQuotient);
return result;
}
Value visitConstantExpr(AffineConstantExpr expr) {
- auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
+ auto op = arith::ConstantIndexOp::create(builder, loc, expr.getValue());
return op.getResult();
}
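
For reference, the select chains rewritten above implement the standard floor- and ceil-division identities on top of truncating `divsi`. A minimal standalone sketch of the same arithmetic (helper names are illustrative; a positive divisor is assumed, as affine semantics require):

    // Sketch of the mod/floordiv/ceildiv expansions above; rhs > 0.
    #include <cassert>

    long long mod(long long lhs, long long rhs) {
      long long r = lhs % rhs;          // remsi
      return r < 0 ? r + rhs : r;       // select on the remainder's sign
    }

    long long floordiv(long long lhs, long long rhs) {
      // Negate-and-decrement negative dividends so truncating division
      // rounds toward negative infinity.
      long long dividend = lhs < 0 ? -1 - lhs : lhs;
      long long quotient = dividend / rhs;
      return lhs < 0 ? -1 - quotient : quotient;
    }

    long long ceildiv(long long lhs, long long rhs) {
      long long dividend = lhs <= 0 ? -lhs : lhs - 1;
      long long quotient = dividend / rhs;
      return lhs <= 0 ? -quotient : quotient + 1;
    }

    int main() {
      assert(mod(-4, 3) == 2);
      assert(floordiv(-4, 3) == -2 && floordiv(4, 3) == 1);
      assert(ceildiv(-4, 3) == -1 && ceildiv(4, 3) == 2);
    }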
@@ -297,9 +298,9 @@ static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
// block.
IRMapping operandMap;
OpBuilder b(hoistOverOp);
- auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
- ifOp.getOperands(),
- /*elseBlock=*/true);
+ auto hoistedIfOp = AffineIfOp::create(b, ifOp.getLoc(), ifOp.getIntegerSet(),
+ ifOp.getOperands(),
+ /*elseBlock=*/true);
// Create a clone of hoistOverOp to use for the else branch of the hoisted
// conditional. The else block may get optimized away if empty.
@@ -368,8 +369,8 @@ mlir::affine::affineParallelize(AffineForOp forOp,
parallelReductions, [](const LoopReduction &red) { return red.value; }));
auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
parallelReductions, [](const LoopReduction &red) { return red.kind; }));
- AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
- loc, ValueRange(reducedValues).getTypes(), reductionKinds,
+ AffineParallelOp newPloop = AffineParallelOp::create(
+ outsideBuilder, loc, ValueRange(reducedValues).getTypes(), reductionKinds,
llvm::ArrayRef(lowerBoundMap), lowerBoundOperands,
llvm::ArrayRef(upperBoundMap), upperBoundOperands,
llvm::ArrayRef(forOp.getStepAsInt()));
@@ -540,7 +541,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
SmallVector<Value, 8> applyOperands{dimOperands};
applyOperands.push_back(iv);
applyOperands.append(symbolOperands.begin(), symbolOperands.end());
- auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
+ auto apply =
+ AffineApplyOp::create(builder, op.getLoc(), map, applyOperands);
iv.replaceAllUsesExcept(apply, apply);
}
@@ -621,8 +623,9 @@ LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
AffineValueMap newIvToOldIvMap;
AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
(void)newIvToOldIvMap.canonicalize();
- auto newIV = opBuilder.create<AffineApplyOp>(
- loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
+ auto newIV =
+ AffineApplyOp::create(opBuilder, loc, newIvToOldIvMap.getAffineMap(),
+ newIvToOldIvMap.getOperands());
op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
return success();
}
@@ -1186,8 +1189,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
for (auto resultExpr : oldMap.getResults()) {
auto singleResMap = AffineMap::get(oldMap.getNumDims(),
oldMap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- oldMapOperands);
+ auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap,
+ oldMapOperands);
oldMemRefOperands.push_back(afOp);
affineApplyOps.push_back(afOp);
}
@@ -1213,8 +1216,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
for (auto resultExpr : indexRemap.getResults()) {
auto singleResMap = AffineMap::get(
indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- remapOperands);
+ auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap,
+ remapOperands);
remapOutputs.push_back(afOp);
affineApplyOps.push_back(afOp);
}
@@ -1263,8 +1266,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// AffineMapAccessInterface, we need to apply the values of `newMapOperands`
// to the `newMap` to get the correct indices.
for (unsigned i = 0; i < newMemRefRank; i++) {
- state.operands.push_back(builder.create<AffineApplyOp>(
- op->getLoc(),
+ state.operands.push_back(AffineApplyOp::create(
+ builder, op->getLoc(),
AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
newMap.getResult(i)),
newMapOperands));
@@ -1449,8 +1452,8 @@ void mlir::affine::createAffineComputationSlice(
for (auto resultExpr : composedMap.getResults()) {
auto singleResMap = AffineMap::get(composedMap.getNumDims(),
composedMap.getNumSymbols(), resultExpr);
- sliceOps->push_back(builder.create<AffineApplyOp>(
- opInst->getLoc(), singleResMap, composedOpOperands));
+ sliceOps->push_back(AffineApplyOp::create(
+ builder, opInst->getLoc(), singleResMap, composedOpOperands));
}
// Construct the new operands that include the results from the composed
@@ -1680,7 +1683,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
    // Create a ConstantOp for the static dimension.
auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
inAffineApply.emplace_back(
- b.create<arith::ConstantOp>(allocOp.getLoc(), constantAttr));
+ arith::ConstantOp::create(b, allocOp.getLoc(), constantAttr));
}
}
@@ -1704,7 +1707,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
AffineMap newMap =
AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
Value affineApp =
- b.create<AffineApplyOp>(allocOp.getLoc(), newMap, inAffineApply);
+ AffineApplyOp::create(b, allocOp.getLoc(), newMap, inAffineApply);
newDynamicSizes.emplace_back(affineApp);
}
newDimIdx++;
@@ -1739,12 +1742,11 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
newDynamicSizes);
    // Add the new dynamic sizes to the new AllocOp.
- newAlloc =
- b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType, newDynamicSizes,
- allocOp.getAlignmentAttr());
+ newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType,
+ newDynamicSizes, allocOp.getAlignmentAttr());
} else {
- newAlloc = b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType,
- allocOp.getAlignmentAttr());
+ newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType,
+ allocOp.getAlignmentAttr());
}
// Replace all uses of the old memref.
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1802,10 +1804,10 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i))
mapOperands[i] =
- b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
- b.create<arith::ConstantIndexOp>(loc, 1));
+ arith::SubIOp::create(b, loc, oldSizes[0].getType(), oldSizes[idx++],
+ arith::ConstantIndexOp::create(b, loc, 1));
else
- mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+ mapOperands[i] = arith::ConstantIndexOp::create(b, loc, oldShape[i] - 1);
}
for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
mapOperands[memrefType.getRank() + i] = oldStrides[i];
@@ -1815,20 +1817,20 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
for (unsigned i = 0; i < newRank; i++) {
if (!newMemRefType.isDynamicDim(i))
continue;
- newSizes.push_back(b.create<AffineApplyOp>(
- loc,
+ newSizes.push_back(AffineApplyOp::create(
+ b, loc,
AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
oldLayoutMap.getResult(i)),
mapOperands));
}
for (unsigned i = 0, e = newSizes.size(); i < e; i++) {
newSizes[i] =
- b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
- b.create<arith::ConstantIndexOp>(loc, 1));
+ arith::AddIOp::create(b, loc, newSizes[i].getType(), newSizes[i],
+ arith::ConstantIndexOp::create(b, loc, 1));
}
// Create the new reinterpret_cast op.
- auto newReinterpretCast = b.create<memref::ReinterpretCastOp>(
- loc, newMemRefType, reinterpretCastOp.getSource(),
+ auto newReinterpretCast = memref::ReinterpretCastOp::create(
+ b, loc, newMemRefType, reinterpretCastOp.getSource(),
/*offsets=*/ValueRange(), newSizes,
/*strides=*/ValueRange(),
/*static_offsets=*/newStaticOffsets,
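
The change running through this whole diff is mechanical: the templated `builder.create<OpTy>(...)` entry point becomes the static `OpTy::create(builder, ...)` method, with the builder hoisted to the first argument and everything else unchanged. Schematically:

    // Old spelling: op class as a template argument on the builder.
    //   auto apply = builder.create<AffineApplyOp>(loc, map, operands);
    // New spelling: static create on the op class itself.
    //   auto apply = AffineApplyOp::create(builder, loc, map, operands);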
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index ebcb951..e7cbee6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -64,7 +64,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
return ConstantOp::materialize(builder, value, type, loc);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 910334b..488c3c3 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2498,7 +2498,7 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
matchPattern(adaptor.getFalseValue(), m_Zero()))
return condition;
- if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
+ if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
auto pred = cmp.getPredicate();
if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
auto cmpLhs = cmp.getLhs();
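
The `SelectOp::fold` change is a separate cleanup: `Value::getDefiningOp<OpTy>()` performs the same null-safe cast as `dyn_cast_or_null` over `getDefiningOp()`, returning null when the value is a block argument or is defined by some other op:

    // Equivalent spellings; the second is the shorthand adopted above.
    //   auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp());
    //   auto cmp = condition.getDefiningOp<arith::CmpIOp>();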
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index f2e7732..9199dcc 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -67,8 +67,8 @@ struct SelectOpInterface
return state.getMemrefWithUniqueOwnership(builder, value,
value.getParentBlock());
- Value ownership = builder.create<arith::SelectOp>(
- op->getLoc(), selectOp.getCondition(),
+ Value ownership = arith::SelectOp::create(
+ builder, op->getLoc(), selectOp.getCondition(),
state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
return {selectOp.getResult(), ownership};
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index afee162..b073a31 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -170,10 +170,10 @@ struct SelectOpInterface
return failure();
if (trueBuffer.getType() != *targetType)
trueBuffer =
- rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer);
+ memref::CastOp::create(rewriter, loc, *targetType, trueBuffer);
if (falseBuffer.getType() != *targetType)
falseBuffer =
- rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer);
+ memref::CastOp::create(rewriter, loc, *targetType, falseBuffer);
}
replaceOpWithNewBufferizedOp<arith::SelectOp>(
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index f96bda6..93682a9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -27,7 +27,7 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
- MLIRMeshDialect
+ MLIRShardDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 55b757c..7626d35 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -75,7 +75,7 @@ LogicalResult EmulateFloatPattern::matchAndRewrite(
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
if (oldType != newType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+ auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
truncFOp.setFastmath(arith::FastMathFlags::contract);
res = truncFOp.getResult();
}
@@ -98,7 +98,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ auto extFOp = arith::ExtFOp::create(b, loc, target, input);
extFOp.setFastmath(arith::FastMathFlags::contract);
return extFOp;
});
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index d5d1559..efe6ad2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -72,7 +72,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
// Scalarize the result in case of 1D vectors.
if (shape.size() == 1)
- return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
+ return vector::ExtractOp::create(rewriter, loc, input, lastOffset);
SmallVector<int64_t> offsets(shape.size(), 0);
offsets.back() = lastOffset;
@@ -80,8 +80,8 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
sizes.back() = 1;
SmallVector<int64_t> strides(shape.size(), 1);
- return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
- sizes, strides);
+ return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets,
+ sizes, strides);
}
/// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
@@ -107,7 +107,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
assert(shape.back() == 1 && "Expected the last vector dim to be x1");
auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
- return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
+ return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input);
}
/// Performs a vector shape cast to append an x1 dimension. If the
@@ -122,7 +122,7 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
auto newShape = llvm::to_vector(vecTy.getShape());
newShape.push_back(1);
auto newTy = VectorType::get(newShape, vecTy.getElementType());
- return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
+ return vector::ShapeCastOp::create(rewriter, loc, newTy, input);
}
/// Inserts the `source` vector slice into the `dest` vector at offset
@@ -136,13 +136,13 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
// Handle scalar source.
if (isa<IntegerType>(source.getType()))
- return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
+ return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset);
SmallVector<int64_t> offsets(shape.size(), 0);
offsets.back() = lastOffset;
SmallVector<int64_t> strides(shape.size(), 1);
- return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
- offsets, strides);
+ return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest,
+ offsets, strides);
}
/// Constructs a new vector of type `resultType` by creating a series of
@@ -254,12 +254,12 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
auto lowSum =
- rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
+ arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
Value overflowVal =
- rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
+ arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow());
- Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
- Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
+ Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1);
+ Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
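
The add emulation feeds the `addui_extended` overflow bit into the high halves as a carry. A standalone sketch of the same propagation, assuming an i64 emulated as two i32 halves, low half first (helper name illustrative):

    #include <cassert>
    #include <cstdint>

    void wideAdd(uint32_t l0, uint32_t l1, uint32_t r0, uint32_t r1,
                 uint32_t &low, uint32_t &high) {
      low = l0 + r0;                        // addui_extended: the sum...
      uint32_t overflow = low < l0 ? 1 : 0; // ...and its carry, zero-extended
      high = l1 + overflow + r1;            // two plain adds on the high halves
    }

    int main() {
      uint32_t lo, hi;
      wideAdd(0xffffffffu, 0, 1, 0, lo, hi); // 0xffffffff + 1
      assert(lo == 0 && hi == 1);
    }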
@@ -293,8 +293,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
auto [rhsElem0, rhsElem1] =
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
- Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
- Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
+ Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
rewriter.replaceOp(op, resultVec);
@@ -346,26 +346,26 @@ struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
Value lowCmp =
- rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
+ arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0);
Value highCmp =
- rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
+ arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1);
Value cmpResult{};
switch (highPred) {
case arith::CmpIPredicate::eq: {
- cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
+ cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp);
break;
}
case arith::CmpIPredicate::ne: {
- cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
+ cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp);
break;
}
default: {
// Handle inequality checks.
- Value highEq = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
+ Value highEq = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
cmpResult =
- rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
+ arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp);
break;
}
}
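
The wide compare is lexicographic: equality distributes across the halves, and the ordered predicates use the high-half compare unless the high halves tie, in which case the unsigned low-half compare decides. A sketch for `ult` over i32 halves:

    #include <cassert>
    #include <cstdint>

    bool wideUlt(uint32_t l0, uint32_t l1, uint32_t r0, uint32_t r1) {
      bool lowCmp = l0 < r0;   // low halves, always an unsigned predicate
      bool highCmp = l1 < r1;  // high halves, the original predicate
      bool highEq = l1 == r1;
      return highEq ? lowCmp : highCmp; // the arith.select above
    }

    int main() {
      assert(wideUlt(5, 1, 0, 2)); // high halves decide
      assert(wideUlt(1, 7, 2, 7)); // tie on high, low decides
    }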
@@ -401,14 +401,14 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
// Multiplying two i2N integers produces (at most) an i4N result, but
    // because the calculation of the top i2N bits is not necessary, we omit it.
auto mulLowLow =
- rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
- Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
- Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
+ arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1);
+ Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0);
Value resLow = mulLowLow.getLow();
Value resHi =
- rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
- resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
+ arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi);
+ resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resLow, resHi});
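
Only the low 2N bits of the 4N-bit product survive, so the top partial product is dropped and the cross terms are allowed to wrap. The same decomposition over 32-bit halves:

    #include <cassert>
    #include <cstdint>

    void wideMul(uint32_t a0, uint32_t a1, uint32_t b0, uint32_t b1,
                 uint32_t &low, uint32_t &high) {
      uint64_t a0b0 = uint64_t(a0) * b0;   // mului_extended: full low product
      low = uint32_t(a0b0);
      high = uint32_t(a0b0 >> 32)          // its high half...
             + a0 * b1 + a1 * b0;          // ...plus the wrapping cross terms
    }

    int main() {
      uint32_t lo, hi;
      wideMul(0xffffffffu, 0, 0xffffffffu, 0, lo, hi); // (2^32 - 1)^2
      assert(lo == 1 && hi == 0xfffffffeu);
    }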
@@ -443,10 +443,10 @@ struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
loc, newResultComponentTy, newOperand);
Value operandZeroCst =
createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
- Value signBit = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
+ Value signBit = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
Value signValue =
- rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
+ arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {extended, signValue});
@@ -508,7 +508,7 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
// Rewrite Max*I/Min*I as compare and select over original operands. Let
// the CmpI and Select emulation patterns handle the final legalization.
Value cmp =
- rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
+ arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs());
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
op.getRhs());
return success();
@@ -587,7 +587,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
// Sign or zero-extend the result. Let the matching conversion pattern
// legalize the extension op.
Value underlyingVal =
- rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
+ CastOp::create(rewriter, loc, narrowTy, adaptor.getIn());
rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
return success();
}
@@ -616,9 +616,9 @@ struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
Value resElem0 =
- rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
+ arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0);
Value resElem1 =
- rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
+ arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
rewriter.replaceOp(op, resultVec);
@@ -680,33 +680,33 @@ struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
Value elemBitWidth =
createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
- Value illegalElemShift = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+ Value illegalElemShift = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
Value shiftedElem0 =
- rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
- Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
- zeroCst, shiftedElem0);
+ arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ zeroCst, shiftedElem0);
- Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, elemBitWidth, rhsElem0);
+ Value cappedShiftAmount = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
Value rightShiftAmount =
- rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+ arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
Value shiftedRight =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount);
Value overshotShiftAmount =
- rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+ arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
Value shiftedLeft =
- rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
+ arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount);
Value shiftedElem1 =
- rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
- Value resElem1High = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, zeroCst, shiftedElem1);
- Value resElem1Low = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, shiftedLeft, shiftedRight);
+ arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0);
+ Value resElem1High = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, zeroCst, shiftedElem1);
+ Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ shiftedLeft, shiftedRight);
Value resElem1 =
- rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
+ arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
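
Written with branches rather than the select chain (the selects exist to cap shift amounts, since shifting an iN value by N or more is poison), the left-shift emulation computes the following, assuming the shift amount stays below 2N:

    #include <cassert>
    #include <cstdint>

    void wideShl(uint32_t l0, uint32_t l1, uint32_t s,
                 uint32_t &r0, uint32_t &r1) {
      if (s >= 32) {                        // low word shifted out entirely
        r0 = 0;
        r1 = l0 << (s - 32);
      } else {
        r0 = l0 << s;
        r1 = (l1 << s) | (s ? l0 >> (32 - s) : 0); // bits crossing the seam
      }
    }

    int main() {
      uint32_t r0, r1;
      wideShl(0x80000000u, 0, 1, r0, r1);
      assert(r0 == 0 && r1 == 1);           // top bit crosses into high word
      wideShl(1, 0, 33, r0, r1);
      assert(r0 == 0 && r1 == 2);
    }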
@@ -769,33 +769,33 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
Value elemBitWidth =
createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
- Value illegalElemShift = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+ Value illegalElemShift = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
Value shiftedElem0 =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
- Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
- zeroCst, shiftedElem0);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0);
+ Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ zeroCst, shiftedElem0);
Value shiftedElem1 =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
- Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
- zeroCst, shiftedElem1);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0);
+ Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
+ zeroCst, shiftedElem1);
- Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, elemBitWidth, rhsElem0);
+ Value cappedShiftAmount = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
Value leftShiftAmount =
- rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+ arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
Value shiftedLeft =
- rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
+ arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount);
Value overshotShiftAmount =
- rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+ arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
Value shiftedRight =
- rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
+ arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount);
- Value resElem0High = rewriter.create<arith::SelectOp>(
- loc, illegalElemShift, shiftedRight, shiftedLeft);
+ Value resElem0High = arith::SelectOp::create(
+ rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft);
Value resElem0 =
- rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
+ arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High);
Value resultVec =
constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
@@ -832,33 +832,33 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
// Perform as many ops over the narrow integer type as possible and let the
// other emulation patterns convert the rest.
Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
- Value signBit = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
+ Value signBit = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
signBit = dropTrailingX1Dim(rewriter, loc, signBit);
// Create a bit pattern of either all ones or all zeros. Then shift it left
// to calculate the sign extension bits created by shifting the original
// sign bit right.
- Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
+ Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit);
Value maxShift =
createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
Value numNonSignExtBits =
- rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
+ arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0);
numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
numNonSignExtBits =
- rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
+ arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits);
Value signBits =
- rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
+ arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits);
// Use original arguments to create the right shift.
Value shrui =
- rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
- Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
+ arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs());
+ Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits);
// Handle shifting by zero. This is necessary when the `signBits` shift is
// invalid.
- Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- rhsElem0, elemZero);
+ Value isNoop = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero);
isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
shrsi);
@@ -892,14 +892,14 @@ struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
// Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where
// CARRY is 1 or 0.
- Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
+ Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0);
// We have a carry if lhsElem0 < rhsElem0.
- Value carry0 = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
- Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
+ Value carry0 = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
+ Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0);
- Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
- Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
+ Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal);
+ Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1);
Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
rewriter.replaceOp(op, resultVec);
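
A standalone sketch of the borrow chain the comment describes, over 32-bit halves:

    #include <cassert>
    #include <cstdint>

    void wideSub(uint32_t l0, uint32_t l1, uint32_t r0, uint32_t r1,
                 uint32_t &low, uint32_t &high) {
      low = l0 - r0;
      uint32_t borrow = l0 < r0 ? 1 : 0; // cmpi ult, zero-extended
      high = l1 - borrow - r1;
    }

    int main() {
      uint32_t lo, hi;
      wideSub(0, 1, 1, 0, lo, hi);       // 0x1'00000000 - 1
      assert(lo == 0xffffffffu && hi == 0);
    }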
@@ -933,13 +933,13 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
// result or not based on that sign bit. We implement negation by
    // subtracting from zero. Note that this relies on the other conversion
// patterns to legalize created ops and narrow the bit widths.
- Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- in, zeroCst);
- Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
- Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
+ Value isNeg = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::slt, in, zeroCst);
+ Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in);
+ Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in);
- Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
- Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
+ Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs);
+ Value negResult = arith::NegFOp::create(rewriter, loc, absResult);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
absResult);
return success();
@@ -985,13 +985,13 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
//
    // Note 2: We do not strictly need the `hi == 0` case, but it makes
// constant folding easier.
- Value hiEqZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
+ Value hiEqZero = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
Type resultTy = op.getType();
Type resultElemTy = getElementTypeOrSelf(resultTy);
- Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
- Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
+ Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt);
+ Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt);
int64_t pow2Int = int64_t(1) << newBitWidth;
TypedAttr pow2Attr =
@@ -999,10 +999,11 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
if (auto vecTy = dyn_cast<VectorType>(resultTy))
pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
- Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
+ Value pow2Val =
+ arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr);
- Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
- Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
+ Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val);
+ Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
return success();
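
Numerically the unsigned conversion reassembles `lowFp + hiFp * 2^N`; the `hi == 0` select exists only to help constant folding, as the comment notes. With N = 32:

    #include <cassert>
    #include <cstdint>

    double wideUiToFp(uint32_t lo, uint32_t hi) {
      double lowFp = double(lo);
      double hiFp = double(hi);
      return hi == 0 ? lowFp : lowFp + hiFp * 4294967296.0; // 2^32
    }

    int main() {
      assert(wideUiToFp(0, 1) == 4294967296.0);
      assert(wideUiToFp(5, 0) == 5.0);
    }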
@@ -1037,22 +1038,22 @@ struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
// result is UB.
TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
- Value zeroCst = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+ Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr);
Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0);
// Get the absolute value. One could have used math.absf here, but that
// introduces an extra dependency.
- Value isNeg = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
- inFp, zeroCst);
- Value negInFp = rewriter.create<arith::NegFOp>(loc, inFp);
+ Value isNeg = arith::CmpFOp::create(
+ rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst);
+ Value negInFp = arith::NegFOp::create(rewriter, loc, inFp);
- Value absVal = rewriter.create<arith::SelectOp>(loc, isNeg, negInFp, inFp);
+ Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp);
// Defer the absolute value to fptoui.
- Value res = rewriter.create<arith::FPToUIOp>(loc, intTy, absVal);
+ Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal);
    // Negate the value if < 0.
- Value neg = rewriter.create<arith::SubIOp>(loc, zeroCstInt, res);
+ Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
return success();
@@ -1109,17 +1110,17 @@ struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
if (auto vecType = dyn_cast<VectorType>(fpTy))
powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
Value powBitwidthFloatCst =
- rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);
+ arith::ConstantOp::create(rewriter, loc, powBitwidthAttr);
Value fpDivPowBitwidth =
- rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
+ arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
Value resHigh =
- rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
+ arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth);
// Calculate fp - resHigh * 2^N by getting the remainder of the division
Value remainder =
- rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
+ arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
Value resLow =
- rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);
+ arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder);
Value high = appendX1Dim(rewriter, loc, resHigh);
Value low = appendX1Dim(rewriter, loc, resLow);
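
The float is split as quotient and remainder by 2^N so that each half fits the narrow integer type. The same math with N = 32:

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    void wideFpToUi(double fp, uint32_t &lo, uint32_t &hi) {
      const double pow2N = 4294967296.0;   // 2^32
      hi = uint32_t(fp / pow2N);           // truncating, like fptoui
      lo = uint32_t(std::fmod(fp, pow2N)); // the remf above
    }

    int main() {
      uint32_t lo, hi;
      wideFpToUi(4294967298.0, lo, hi);    // 2^32 + 2
      assert(hi == 1 && lo == 2);
    }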
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index e842f44..f8fa35c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -28,10 +28,10 @@ static Value createConst(Location loc, Type type, int value,
PatternRewriter &rewriter) {
auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return rewriter.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(rewriter, loc, attr);
}
/// Create a float constant.
@@ -39,11 +39,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value,
PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
- return rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(shapedTy, attr));
+ return arith::ConstantOp::create(rewriter, loc,
+ DenseElementsAttr::get(shapedTy, attr));
}
- return rewriter.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(rewriter, loc, attr);
}
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
@@ -67,11 +67,11 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
Value b = op.getRhs();
Value zero = createConst(loc, a.getType(), 0, rewriter);
Value compare =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
Value one = createConst(loc, a.getType(), 1, rewriter);
- Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
- Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
- Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+ Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
+ Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
+ Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
return success();
}
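
This is the classic overflow-safe form, with the `a == 0` compare guarding the `a - 1` wraparound:

    #include <cassert>
    #include <cstdint>

    uint32_t ceilDivUI(uint32_t a, uint32_t b) {
      return a == 0 ? 0 : (a - 1) / b + 1; // never computes a + b - 1
    }

    int main() {
      assert(ceilDivUI(0, 7) == 0);
      assert(ceilDivUI(7, 7) == 1);
      assert(ceilDivUI(8, 7) == 2);
    }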
@@ -96,22 +96,22 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
Value zero = createConst(loc, type, 0, rewriter);
Value one = createConst(loc, type, 1, rewriter);
- Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
- Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
- Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, a, product);
+ Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
+ Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
+ Value notEqualDivisor = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne, a, product);
- Value aNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value bNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
+ Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ a, zero);
+ Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ b, zero);
- Value signEqual = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, aNeg, bNeg);
+ Value signEqual = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
Value cond =
- rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
+ arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
- Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+ Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
quotient);
@@ -135,25 +135,25 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
Value a = op.getLhs();
Value b = op.getRhs();
- Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
- Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
- Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, a, product);
+ Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
+ Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
+ Value notEqualDivisor = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne, a, product);
Value zero = createConst(loc, type, 0, rewriter);
- Value aNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value bNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
+ Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ a, zero);
+ Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ b, zero);
- Value signOpposite = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, aNeg, bNeg);
+ Value signOpposite = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
Value cond =
- rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
+ arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
Value minusOne = createConst(loc, type, -1, rewriter);
Value quotientMinusOne =
- rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
+ arith::AddIOp::create(rewriter, loc, quotient, minusOne);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
quotient);
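
Both signed expansions nudge truncating division by one when the remainder is nonzero: toward positive infinity when the operand signs match (ceildivsi), toward negative infinity when they differ (floordivsi). Compactly, using `%` in place of the `a != quotient * b` check above:

    #include <cassert>

    int ceilDivSI(int a, int b) {
      int q = a / b;
      bool adjust = (a % b != 0) && ((a < 0) == (b < 0));
      return adjust ? q + 1 : q;
    }

    int floorDivSI(int a, int b) {
      int q = a / b;
      bool adjust = (a % b != 0) && ((a < 0) != (b < 0));
      return adjust ? q - 1 : q;
    }

    int main() {
      assert(ceilDivSI(7, 2) == 4 && ceilDivSI(-7, 2) == -3);
      assert(floorDivSI(7, 2) == 3 && floorDivSI(-7, 2) == -4);
    }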
@@ -171,7 +171,7 @@ public:
Value lhs = op.getLhs();
Value rhs = op.getRhs();
- Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
+ Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
return success();
}
@@ -192,12 +192,12 @@ public:
static_assert(pred == arith::CmpFPredicate::UGT ||
pred == arith::CmpFPredicate::ULT,
"pred must be either UGT or ULT");
- Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
- Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
+ Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
// Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
- Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
- rhs, rhs);
+ Value isNaN = arith::CmpFOp::create(rewriter, loc,
+ arith::CmpFPredicate::UNO, rhs, rhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
return success();
}
@@ -218,12 +218,12 @@ public:
static_assert(pred == arith::CmpFPredicate::UGT ||
pred == arith::CmpFPredicate::ULT,
"pred must be either UGT or ULT");
- Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
- Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
+ Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
// Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
- Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
- lhs, lhs);
+ Value isNaN = arith::CmpFOp::create(rewriter, loc,
+ arith::CmpFPredicate::UNO, lhs, lhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
return success();
}
@@ -247,12 +247,12 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
- Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
+ Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
- Value shl = b.create<arith::ShLIOp>(exti, c16);
- Value result = b.create<arith::BitcastOp>(resultTy, shl);
+ Value shl = arith::ShLIOp::create(b, exti, c16);
+ Value result = arith::BitcastOp::create(b, resultTy, shl);
rewriter.replaceOp(op, result);
return success();
@@ -296,7 +296,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// exponent bits, that simple truncation is the desired outcome for
// infinities.
Value isNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
@@ -305,30 +305,30 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
// Reinterpret the input f32 value as bits.
- Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
// Read bit 16 as a value in {0,1}.
Value bit16 =
- b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+ arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
// Determine the rounding bias to add as either 0x7fff or 0x8000 depending
// on bit 16, implementing the tie-breaking "to nearest even".
- Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+ Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
// Add the rounding bias. Generally we want this to be added to the
    // mantissa, but nothing prevents it from carrying into the exponent
// bits, which would feel like a bug, but this is the magic trick here:
// when that happens, the mantissa gets reset to zero and the exponent
// gets incremented by the carry... which is actually exactly what we
// want.
- Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+ Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
- Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+ Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
Value normalCaseResultI16 =
- b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+ arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
- Value result = b.create<arith::BitcastOp>(resultTy, select);
+ arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
+ Value result = arith::BitcastOp::create(b, resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
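
The rounding-bias trick is worth restating outside the IR: adding `0x7fff` plus bit 16 of the input makes plain truncation of the low 16 bits round to nearest with ties to even, and NaN inputs are replaced by the quiet-NaN pattern `0x7fc0`. Standalone:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    uint16_t truncToBF16Bits(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      if (f != f)                                  // cmpf une: NaN check
        return 0x7fc0;
      uint32_t bias = 0x7fff + ((bits >> 16) & 1); // ties go to even
      return uint16_t((bits + bias) >> 16);
    }

    int main() {
      assert(truncToBF16Bits(1.0f) == 0x3f80);     // exactly representable
    }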
@@ -381,7 +381,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
- Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+ Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
@@ -390,38 +390,39 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
// Set last Exponent bit and Mantissa.
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
- Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
+ Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
Value isHalf =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
- bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
- bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
- bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+ bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
+ bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
+ bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
// Set first 7 bits of Exponent.
Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
Value useLargerExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
Value bits25To31 =
- b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
+ arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
Value zeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
- bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
+ bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
// Set sign.
Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
Value negative =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
- Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
+ Value bit32 =
+ arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
// Add segments together.
- Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
- Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
- Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
+ Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
+ Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
+ Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
if (!isa<Float32Type>(resultETy))
- result = b.create<arith::TruncFOp>(resultTy, result);
+ result = arith::TruncFOp::create(b, resultTy, result);
rewriter.replaceOp(op, result);
return success();
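
For cross-checking the bit manipulation above: the eight nonnegative F4E2M1 magnitudes can simply be tabulated, bit 3 is the sign, and the format tops out at ±6.0 (hence the clamp in F4E2M1TruncFOpConverter further down). A lookup sketch:

    #include <cassert>

    float extF4E2M1(unsigned nibble) {
      static const float mag[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                   2.0f, 3.0f, 4.0f, 6.0f};
      float v = mag[nibble & 7];
      return (nibble & 8) ? -v : v;
    }

    int main() {
      assert(extF4E2M1(0x1) == 0.5f);  // the isHalf special case above
      assert(extF4E2M1(0xf) == -6.0f); // largest negative value
    }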
@@ -447,25 +448,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
- Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
    // Create constants for NaNs.
Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
- Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
- Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+ Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
+ Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
Value isNan =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
    // Select for NaNs.
- f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
- Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+ Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
if (resultETy.getIntOrFloatBitWidth() < 32) {
- result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
- op.getFastmathAttr());
+ result = arith::TruncFOp::create(b, resultTy, result, nullptr,
+ op.getFastmathAttr());
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
- result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
+ result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
}
rewriter.replaceOp(op, result);
return success();
@@ -520,7 +521,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
if (!isa<Float4E2M1FNType>(resultETy))
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
if (!isa<Float32Type>(operandETy))
- operand = b.create<arith::ExtFOp>(f32Ty, operand);
+ operand = arith::ExtFOp::create(b, f32Ty, operand);
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
@@ -532,65 +533,65 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Step 0: Clamp to bounds.
Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
- Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
- operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
- Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+ Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
+ operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
+ Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
// Step 1: Set sign bit.
    Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 31
- Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
- Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
- Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
+ Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
+ Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
+ Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
// Step 2: Convert exponent by adjusting bias.
Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
Value cF4MantissaWidth = c0x1; // 1
Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
- Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+ Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
Value biasAdjustedSignExp =
- b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
- Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
- f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
- f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
+ arith::SubIOp::create(b, f32SignExp, biasAdjustment);
+ Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
+ f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
+ f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
// Step 3: Set mantissa to first bit.
Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
- Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
- man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
- Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
- f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
+ Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
+ man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
+ Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
+ f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
// Step 4: Special consideration for conversion to 0.5.
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
- Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+ Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
Value isSubnormal =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
Value isNegOneExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
- Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
- Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
- man23Bits, zeroExpBits);
- Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
+ Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
+ Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
+ man23Bits, zeroExpBits);
+ Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
Value isZeroExp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
Value subResult =
- b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
- subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
- f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
+ arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
+ subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
+ f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
// Step 5: Round up if necessary.
Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
- Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+ Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
Value shouldRound =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
- shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
- Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
- f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
+ shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
+ Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
+ f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
- Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+ Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
rewriter.replaceOp(op, result);
return success();
}
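
Every hunk in this file follows the same mechanical rewrite: the `OpBuilder` member template `b.create<OpTy>(...)` becomes the static factory `OpTy::create(b, ...)`, with the builder hoisted to the first argument. A minimal side-by-side sketch (hypothetical operands; both forms build identical IR):

    // Old form: the op type is a template argument on the builder.
    Value sum = b.create<arith::AddIOp>(loc, lhs, rhs);
    // New form: static create on the op class, builder passed explicitly.
    Value sum2 = arith::AddIOp::create(b, loc, lhs, rhs);
    // With an ImplicitLocOpBuilder the location is implied in both forms,
    // which is why the converters above pass no loc:
    Value sum3 = arith::AddIOp::create(implicitB, lhs, rhs);
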
@@ -625,16 +626,16 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
if (operandETy.getIntOrFloatBitWidth() < 32) {
- operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
+ operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
- operand = b.create<arith::TruncFOp>(
- f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
+ operand = arith::TruncFOp::create(
+ b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
}
- Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+ Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
- Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
- Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
- Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+ Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
+ Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
+ Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
rewriter.replaceOp(op, result);
return success();
}
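
The f8E8M0 truncation keeps only the exponent byte of the f32: shifting the bit pattern right by the 23-bit mantissa width leaves sign plus exponent, and the trunc to i8 drops the sign bit. A host-side analogue of the three emitted ops (illustrative sketch, not MLIR API):

    #include <cstdint>
    #include <cstring>

    uint8_t truncF32ToE8M0(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof bits); // arith.bitcast : f32 -> i32
      uint32_t signExp = bits >> 23;       // arith.shrui by cF32MantissaWidth
      return uint8_t(signExp);             // arith.trunci : i32 -> i8
    }
    // truncF32ToE8M0(1.0f) == 0x7F  (biased exponent 127, i.e. 2^0)
    // truncF32ToE8M0(2.0f) == 0x80  (biased exponent 128, i.e. 2^1)
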
@@ -653,8 +654,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
- scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
- op.getFastmathAttr());
+ scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
+ op.getFastmathAttr());
}
// Catch scale types like f8E5M2.
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
@@ -666,11 +667,11 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
    // extf on scale will essentially create a floating-point number
    // of type resultTy that is 2^scale and will also propagate NaNs
Value scaleExt =
- b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+ arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
Value inputExt =
- b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+ arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
Value result =
- b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
+ arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
rewriter.replaceOp(op, result);
return success();
}
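
Semantically the pattern computes extf(input) * extf(scale), where extending the f8E8M0 scale reconstructs exactly 2^(e - 127). A host-side sketch of that arithmetic, with std::ldexp standing in for the extf (the NaN encoding is ignored here):

    #include <cmath>
    #include <cstdint>

    // result = input * 2^(scaleByte - 127); scaleByte is a raw E8M0 value.
    float scalingExtF(float input, uint8_t scaleByte) {
      float scale = std::ldexp(1.0f, int(scaleByte) - 127); // extf of scale
      return input * scale;                                 // arith.mulf
    }
    // scalingExtF(3.0f, 0x80) == 6.0f   (0x80 encodes 2^1)
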
@@ -695,8 +696,8 @@ struct ScalingTruncFOpConverter
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
- scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
- op.getFastmathAttr());
+ scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
+ op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
@@ -708,11 +709,11 @@ struct ScalingTruncFOpConverter
// this will create a floating point number of type
// inputTy that is 2^scale and will also propagate NaNs
scaleOperand =
- b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
- Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
- op.getFastmathAttr());
- Value resultCast = b.create<arith::TruncFOp>(
- resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
+ arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
+ Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
+ op.getFastmathAttr());
+ Value resultCast = arith::TruncFOp::create(
+ b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
rewriter.replaceOp(op, resultCast);
return success();
}
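
The trunc converter is the dual: divide by the materialized 2^scale, then narrow to the result type under the op's rounding mode. In the same sketch notation (std::ldexp from <cmath> as above):

    // result = truncf(input / 2^(scaleByte - 127), roundingMode)
    float scalingTruncFStep(float input, uint8_t scaleByte) {
      float scale = std::ldexp(1.0f, int(scaleByte) - 127); // extf of scale
      return input / scale; // arith.divf; the final truncf narrows resultTy
    }
    // scalingTruncFStep(6.0f, 0x80) == 3.0f, then truncated to resultTy
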
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f2f9388..777ff0e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -305,18 +305,18 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
if (castKind == CastKind::Signed)
- return builder.create<arith::IndexCastOp>(loc, dstType, src);
- return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+ return arith::IndexCastOp::create(builder, loc, dstType, src);
+ return arith::IndexCastUIOp::create(builder, loc, dstType, src);
}
auto srcInt = cast<IntegerType>(srcElemType);
auto dstInt = cast<IntegerType>(dstElemType);
if (dstInt.getWidth() < srcInt.getWidth())
- return builder.create<arith::TruncIOp>(loc, dstType, src);
+ return arith::TruncIOp::create(builder, loc, dstType, src);
if (castKind == CastKind::Signed)
- return builder.create<arith::ExtSIOp>(loc, dstType, src);
- return builder.create<arith::ExtUIOp>(loc, dstType, src);
+ return arith::ExtSIOp::create(builder, loc, dstType, src);
+ return arith::ExtUIOp::create(builder, loc, dstType, src);
}
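
The dispatch in doCast reduces to a small decision table (sketch of the branches above):

    //   src or dst is index      : arith.index_cast / arith.index_castui
    //   dst narrower than src    : arith.trunci
    //   otherwise (dst is wider) : arith.extsi (signed) / arith.extui
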
struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 5fb7953..4bdd1e6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -23,8 +23,8 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
switch (e.getKind()) {
case AffineExprKind::Constant:
- return b.create<ConstantIndexOp>(loc,
- cast<AffineConstantExpr>(e).getValue());
+ return ConstantIndexOp::create(b, loc,
+ cast<AffineConstantExpr>(e).getValue());
case AffineExprKind::DimId:
return operands[cast<AffineDimExpr>(e).getPosition()];
case AffineExprKind::SymbolId:
@@ -32,28 +32,28 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
map.getNumDims()];
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return AddIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return MulIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::FloorDiv: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return DivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::CeilDiv: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return CeilDivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(e);
- return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
- buildExpr(binaryExpr.getRHS()));
+ return RemSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()),
+ buildExpr(binaryExpr.getRHS()));
}
}
llvm_unreachable("unsupported AffineExpr kind");
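
buildExpr lowers an affine expression by structural recursion: constants become index constants, dims and symbols index into `operands`, and each binary kind maps to one arith op. Reifying (d0 + 4) * s0, for example, would hand-expand to:

    Value c4  = ConstantIndexOp::create(b, loc, 4);          // Constant
    Value add = AddIOp::create(b, loc, operands[0], c4);     // d0 + 4
    Value mul = MulIOp::create(                              // ... * s0
        b, loc, add, operands[map.getNumDims()]);
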
@@ -89,10 +89,10 @@ FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
"expected dynamic dim");
if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
- operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
+ operands.push_back(tensor::DimOp::create(b, loc, value, *dim));
} else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
- operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
+ operands.push_back(memref::DimOp::create(b, loc, value, *dim));
} else {
llvm_unreachable("cannot generate DimOp for unsupported shaped type");
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index 3478adc..3e34246 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -6,22 +6,22 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
using namespace mlir::arith;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
// Sharding of arith.constant
// RankedTensor constants can be sharded like any other tensor.
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+// %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
// Scalar constants are always replicated and need no sharding annotation.
struct ConstantShardingInterface
@@ -48,8 +48,8 @@ struct ConstantShardingInterface
// Otherwise mirror result sharding if it is a tensor constant.
// Otherwise return replication option.
FailureOr<ShardingOption>
- getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings) const {
+ getShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings) const {
assert(resultShardings.size() == 1 &&
"Expecting exactly one result sharding for arith.constant");
auto resultSharding = resultShardings[0];
@@ -61,17 +61,17 @@ struct ConstantShardingInterface
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
}
- return ShardingOption(axesArray, resultSharding.getMeshAttr());
+ return ShardingOption(axesArray, resultSharding.getGridAttr());
}
- return ShardingOption({}, resultSharding.getMeshAttr());
+ return ShardingOption({}, resultSharding.getGridAttr());
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
@@ -80,15 +80,15 @@ struct ConstantShardingInterface
}
auto sharding = resultShardings[0];
auto newType = cast<RankedTensorType>(shardType(
- cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+ cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
- auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
- spmdizationMap.map(op->getResult(0), newOp.getResult());
- spmdizationMap.map(op, newOp.getOperation());
+ auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue);
+ partitionMap.map(op->getResult(0), newOp.getResult());
+ partitionMap.map(op, newOp.getOperation());
} else {
// `clone` will populate the mapping of old to new results.
- (void)builder.clone(*op, spmdizationMap);
+ (void)builder.clone(*op, partitionMap);
}
return success();
}
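
Concretely, partitioning a splat constant only rewrites its type to the shard-local shape. Reusing the example from the comment above (grid assumed to be 4 along the split axis):

    // Before partitioning:
    //   %cst = arith.constant dense<0.0> : tensor<1024x1024xf32>
    //   with split_axes = [[0]] on @grid4x4
    // After resizeSplat, every shard along axis 0 holds:
    //   %cst = arith.constant dense<0.0> : tensor<256x1024xf32>
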
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bdeeccb..b1fc9aa 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -67,7 +67,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
// dynamism.
Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
Value indexGroupStaticSizesProduct =
- b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+ arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt);
Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
loc, indexGroupSize, indexGroupStaticSizesProduct);
outputShapeValues.push_back(dynamicDimSize);
@@ -104,8 +104,8 @@ Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
if (auto value = dyn_cast_if_present<Value>(ofr))
return value;
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
- return b.create<arith::ConstantOp>(
- loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
+ return arith::ConstantOp::create(
+ b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
}
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
@@ -113,7 +113,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
if (auto value = dyn_cast_if_present<Value>(ofr))
return value;
auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
- return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
+ return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue());
}
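
Both helpers share the shape "Value passes through, Attribute gets materialized". A short usage sketch (someIndexValue is a hypothetical existing SSA value):

    OpFoldResult ofr = b.getIndexAttr(42);                 // attribute case
    Value c = getValueOrCreateConstantIndexOp(b, loc, ofr);
    // -> emits arith.constant 42 : index
    Value same =
        getValueOrCreateConstantIndexOp(b, loc, OpFoldResult(someIndexValue));
    // -> returns someIndexValue unchanged; no op is created
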
Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
@@ -124,7 +124,7 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
bool targetIsIndex = targetType.isIndex();
bool valueIsIndex = value.getType().isIndex();
if (targetIsIndex ^ valueIsIndex)
- return b.create<arith::IndexCastOp>(loc, targetType, value);
+ return arith::IndexCastOp::create(b, loc, targetType, value);
auto targetIntegerType = dyn_cast<IntegerType>(targetType);
auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
@@ -133,8 +133,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
- return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
- return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
+ return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
+ return arith::TruncIOp::create(b, loc, targetIntegerType, value);
}
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
@@ -142,21 +142,21 @@ static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
// If operand is floating point, cast directly to the int type.
if (isa<FloatType>(operand.getType())) {
if (isUnsigned)
- return b.create<arith::FPToUIOp>(toType, operand);
- return b.create<arith::FPToSIOp>(toType, operand);
+ return arith::FPToUIOp::create(b, toType, operand);
+ return arith::FPToSIOp::create(b, toType, operand);
}
// Cast index operands directly to the int type.
if (operand.getType().isIndex())
- return b.create<arith::IndexCastOp>(toType, operand);
+ return arith::IndexCastOp::create(b, toType, operand);
if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
// Either extend or truncate.
if (toType.getWidth() > fromIntType.getWidth()) {
if (isUnsigned)
- return b.create<arith::ExtUIOp>(toType, operand);
- return b.create<arith::ExtSIOp>(toType, operand);
+ return arith::ExtUIOp::create(b, toType, operand);
+ return arith::ExtSIOp::create(b, toType, operand);
}
if (toType.getWidth() < fromIntType.getWidth())
- return b.create<arith::TruncIOp>(toType, operand);
+ return arith::TruncIOp::create(b, toType, operand);
return operand;
}
@@ -169,14 +169,14 @@ static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
  // Note that it is unclear how to cast between BF16 and FP16.
if (isa<IntegerType>(operand.getType())) {
if (isUnsigned)
- return b.create<arith::UIToFPOp>(toType, operand);
- return b.create<arith::SIToFPOp>(toType, operand);
+ return arith::UIToFPOp::create(b, toType, operand);
+ return arith::SIToFPOp::create(b, toType, operand);
}
if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
if (toType.getWidth() > fromFpTy.getWidth())
- return b.create<arith::ExtFOp>(toType, operand);
+ return arith::ExtFOp::create(b, toType, operand);
if (toType.getWidth() < fromFpTy.getWidth())
- return b.create<arith::TruncFOp>(toType, operand);
+ return arith::TruncFOp::create(b, toType, operand);
return operand;
}
@@ -189,18 +189,18 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
if (isa<FloatType>(targetType.getElementType()) &&
isa<FloatType>(fromComplexType.getElementType())) {
- Value real = b.create<complex::ReOp>(operand);
- Value imag = b.create<complex::ImOp>(operand);
+ Value real = complex::ReOp::create(b, operand);
+ Value imag = complex::ImOp::create(b, operand);
Type targetETy = targetType.getElementType();
if (targetType.getElementType().getIntOrFloatBitWidth() <
fromComplexType.getElementType().getIntOrFloatBitWidth()) {
- real = b.create<arith::TruncFOp>(targetETy, real);
- imag = b.create<arith::TruncFOp>(targetETy, imag);
+ real = arith::TruncFOp::create(b, targetETy, real);
+ imag = arith::TruncFOp::create(b, targetETy, imag);
} else {
- real = b.create<arith::ExtFOp>(targetETy, real);
- imag = b.create<arith::ExtFOp>(targetETy, imag);
+ real = arith::ExtFOp::create(b, targetETy, real);
+ imag = arith::ExtFOp::create(b, targetETy, imag);
}
- return b.create<complex::CreateOp>(targetType, real, imag);
+ return complex::CreateOp::create(b, targetType, real, imag);
}
}
@@ -209,27 +209,27 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
Value from = operand;
if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
- from = b.create<arith::ExtFOp>(toFpTy, from);
+ from = arith::ExtFOp::create(b, toFpTy, from);
}
if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
- from = b.create<arith::TruncFOp>(toFpTy, from);
+ from = arith::TruncFOp::create(b, toFpTy, from);
}
- Value zero = b.create<mlir::arith::ConstantFloatOp>(
- toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
- return b.create<complex::CreateOp>(targetType, from, zero);
+ Value zero = mlir::arith::ConstantFloatOp::create(
+ b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
+ return complex::CreateOp::create(b, targetType, from, zero);
}
if (isa<IntegerType>(operand.getType())) {
FloatType toFpTy = cast<FloatType>(targetType.getElementType());
Value from = operand;
if (isUnsigned) {
- from = b.create<arith::UIToFPOp>(toFpTy, from);
+ from = arith::UIToFPOp::create(b, toFpTy, from);
} else {
- from = b.create<arith::SIToFPOp>(toFpTy, from);
+ from = arith::SIToFPOp::create(b, toFpTy, from);
}
- Value zero = b.create<mlir::arith::ConstantFloatOp>(
- toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
- return b.create<complex::CreateOp>(targetType, from, zero);
+ Value zero = mlir::arith::ConstantFloatOp::create(
+ b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
+ return complex::CreateOp::create(b, targetType, from, zero);
}
return {};
@@ -277,7 +277,7 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
attr = SplatElementsAttr::get(vecTy, value);
}
- return builder.create<arith::ConstantOp>(loc, attr);
+ return arith::ConstantOp::create(builder, loc, attr);
}
Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
@@ -309,35 +309,35 @@ Type mlir::getType(OpFoldResult ofr) {
}
Value ArithBuilder::_and(Value lhs, Value rhs) {
- return b.create<arith::AndIOp>(loc, lhs, rhs);
+ return arith::AndIOp::create(b, loc, lhs, rhs);
}
Value ArithBuilder::add(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::AddFOp>(loc, lhs, rhs);
- return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
+ return arith::AddFOp::create(b, loc, lhs, rhs);
+ return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
}
Value ArithBuilder::sub(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::SubFOp>(loc, lhs, rhs);
- return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
+ return arith::SubFOp::create(b, loc, lhs, rhs);
+ return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
}
Value ArithBuilder::mul(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::MulFOp>(loc, lhs, rhs);
- return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
+ return arith::MulFOp::create(b, loc, lhs, rhs);
+ return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
}
Value ArithBuilder::sgt(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
- return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
+ return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
+ return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
}
Value ArithBuilder::slt(Value lhs, Value rhs) {
if (isa<FloatType>(lhs.getType()))
- return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
- return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
+ return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
+ return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
}
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
- return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+ return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
}
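
ArithBuilder lets callers write type-generic arithmetic, with the float/int dispatch made per call. For instance, a signed max (x and y hypothetical):

    ArithBuilder ab(b, loc);
    Value gt  = ab.sgt(x, y);        // cmpf OGT for floats, cmpi sgt for ints
    Value max = ab.select(gt, x, y); // arith.select either way
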
namespace mlir::arith {
@@ -348,8 +348,8 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType) {
- Value one = builder.create<ConstantOp>(loc, resultType,
- builder.getOneAttr(resultType));
+ Value one = ConstantOp::create(builder, loc, resultType,
+ builder.getOneAttr(resultType));
ArithBuilder arithBuilder(builder, loc);
return std::accumulate(
values.begin(), values.end(), one,
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 5aadaec..1aa8064 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -49,7 +49,7 @@ std::optional<Value> getExtOperand(Value v) {
// If the operand is not defined by an explicit extend operation of the
  // accepted operation type, allow for an implicit sign-extension.
- auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ auto extOp = v.getDefiningOp<Op>();
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto eltTy = cast<VectorType>(v.getType()).getElementType();
@@ -145,8 +145,8 @@ protected:
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
lhs, rhs);
case MMLA::Bfloat:
- return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
- rhs);
+ return arm_neon::BfmmlaOp::create(rewriter, loc, acc.getType(), acc, lhs,
+ rhs);
case MMLA::Nop:
llvm_unreachable("Uninitialized operation type");
}
@@ -226,8 +226,9 @@ public:
// Initial accumulator for the final result. This is the un-tiled result if
// tiling is done.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+ Value result =
+ arith::ConstantOp::create(rewriter, loc, op.getResultType(),
+ rewriter.getZeroAttr(op.getResultType()));
SmallVector<int64_t, 3> loopOrder = {0, 1};
if (iterationBounds.size() == 3)
@@ -263,8 +264,9 @@ public:
if (dimM == 1) {
auto expandRowVector = [&](Value tiledOperand,
VectorType expandedTypeType) {
- auto emptyOperand = rewriter.create<arith::ConstantOp>(
- loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+ auto emptyOperand =
+ arith::ConstantOp::create(rewriter, loc, expandedTypeType,
+ rewriter.getZeroAttr(expandedTypeType));
SmallVector<int64_t> offsets(
cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
SmallVector<int64_t> strides(
@@ -280,8 +282,8 @@ public:
// using the instruction for unsigned by signed multiplication with
// reversed operands.
if (swapOperands)
- tiledAcc = rewriter.create<vector::TransposeOp>(
- loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+ tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc,
+ ArrayRef<int64_t>({1, 0}));
// Collapse tiled operands to 1D vectors required by the ArmNeon ops
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -309,8 +311,8 @@ public:
// Because of the reversed operands the result is obtained transposed.
    // Transpose it back.
if (swapOperands)
- tiledRes = rewriter.create<vector::TransposeOp>(
- loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+ tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes,
+ ArrayRef<int64_t>({1, 0}));
// With vecmat, only one row of tiled ACC can be inserted into the final
// result
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 5f00cef..e5e1312 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -75,21 +75,21 @@ scf::ForOp createLoopOverTileSlices(
PatternRewriter &rewriter, Location loc, Value initTile,
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
OpBuilder::InsertionGuard g(rewriter);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
- ValueRange{initTile});
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
+ auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
Value nextTile =
makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
/*currentTile=*/forOp.getRegionIterArg(0));
- rewriter.create<scf::YieldOp>(loc, nextTile);
+ scf::YieldOp::create(rewriter, loc, nextTile);
return forOp;
}
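
The helper builds a single scf.for that threads the tile through as an iter_arg, with a vscale-scaled trip count. A sketch of the emitted IR:

    // %c0  = arith.constant 0 : index
    // %c1  = arith.constant 1 : index
    // %min = arith.constant <dim 0 of initTile's type> : index
    // %vs  = vector.vscale
    // %n   = arith.muli %min, %vs : index
    // scf.for %slice = %c0 to %n step %c1 iter_args(%tile = %initTile) {
    //   %next = <makeLoopBody(%slice, %tile)>
    //   scf.yield %next
    // }
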
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 23f2c2b..9bf0265 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -136,7 +136,7 @@ public:
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
- return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
+ return vector::InterleaveOp::create(rewriter, loc, lhs, rhs);
};
auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -284,7 +284,7 @@ public:
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
- return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
+ return vector::InterleaveOp::create(rewriter, loc, lhs, rhs);
};
auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -456,8 +456,8 @@ struct SwapVectorExtractOfArithExtend
Value extendSource = extendOp->getOperand(0);
// Create new extract from source of extend.
- Value newExtract = rewriter.create<vector::ExtractOp>(
- loc, extendSource, extractOp.getMixedPosition());
+ Value newExtract = vector::ExtractOp::create(rewriter, loc, extendSource,
+ extractOp.getMixedPosition());
// Extend new extract to original result type.
Operation *newExtend =
@@ -503,8 +503,9 @@ struct SwapVectorScalableExtractOfArithExtend
// Create new extract from source of extend.
VectorType extractResultVectorType =
resultType.clone(extendSourceVectorType.getElementType());
- Value newExtract = rewriter.create<vector::ScalableExtractOp>(
- loc, extractResultVectorType, extendSource, extractOp.getPos());
+ Value newExtract = vector::ScalableExtractOp::create(
+ rewriter, loc, extractResultVectorType, extendSource,
+ extractOp.getPos());
// Extend new extract to original result type.
Operation *newExtend =
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index b3c988d..d925c19 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -210,7 +210,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
rewriter.setInsertionPointToEnd(source);
- rewriter.create<cf::BranchOp>(loc, dest, args);
+ cf::BranchOp::create(rewriter, loc, dest, args);
};
for (auto condBranch : worklist) {
@@ -253,7 +253,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
for (OpOperand &operand : terminator->getOpOperands()) {
if (isValidSMETileVectorType(operand.get().getType())) {
auto copy =
- rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ CopyTileOp::create(rewriter, terminator->getLoc(), operand.get());
rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
}
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 1e8e126..1c0eced 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -82,13 +82,14 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
Location loc,
ValueRange indices,
ArrayRef<int> scalableOffsets) {
- auto vscale = builder.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(builder, loc);
return llvm::map_to_vector(
llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
auto [index, base] = pair;
- auto offset = builder.create<arith::MulIOp>(
- loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
- return builder.create<arith::AddIOp>(loc, index, offset);
+ auto offset = arith::MulIOp::create(
+ builder, loc, arith::ConstantIndexOp::create(builder, loc, base),
+ vscale);
+ return arith::AddIOp::create(builder, loc, index, offset);
});
}
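
Each index is rebased by a static multiple of vscale, i.e. new[i] = indices[i] + scalableOffsets[i] * vscale; extractSMEMask below passes negative offsets {-row, -col} to translate mask bounds into a sub-tile's local frame:

    // %cBase = arith.constant <scalableOffsets[i]> : index
    // %off   = arith.muli %cBase, %vscale
    // %new   = arith.addi %indices_i, %off
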
@@ -132,8 +133,8 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
// from the mask operands to get the parameters for this sub-tile.
auto smeTileMaskDims = addConstantScalableOffset(
builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
- auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
- loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
+ auto smeTileCreateMask = vector::CreateMaskOp::create(
+ builder, loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
return smeTileCreateMask.getResult();
}
@@ -190,8 +191,8 @@ struct LegalizeArithConstantOpsByDecomposition
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
- auto tileSplat = rewriter.create<arith::ConstantOp>(
- constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+ auto tileSplat = arith::ConstantOp::create(
+ rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
SmallVector<Value> repl(tileCount, tileSplat);
rewriter.replaceOpWithMultiple(constantOp, {repl});
@@ -237,12 +238,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
- auto lhs = rewriter.create<vector::ScalableExtractOp>(
- loc, sliceType, outerProductOp.getLhs(), smeTile.row);
- auto rhs = rewriter.create<vector::ScalableExtractOp>(
- loc, sliceType, outerProductOp.getRhs(), smeTile.col);
- auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
- loc, smeTileType, lhs, rhs,
+ auto lhs = vector::ScalableExtractOp::create(
+ rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row);
+ auto rhs = vector::ScalableExtractOp::create(
+ rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col);
+ auto smeOuterProduct = vector::OuterProductOp::create(
+ rewriter, loc, smeTileType, lhs, rhs,
!accSMETiles.empty() ? accSMETiles[index] : Value{},
outerProductOp.getKind());
@@ -314,8 +315,8 @@ struct LegalizeTransferReadOpsByDecomposition
for (SMESubTile smeTile :
decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
- auto smeRead = rewriter.create<vector::TransferReadOp>(
- loc, smeTileType, readOp.getBase(),
+ auto smeRead = vector::TransferReadOp::create(
+ rewriter, loc, smeTileType, readOp.getBase(),
getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
readOp.getInBoundsAttr());
@@ -363,8 +364,8 @@ struct LegalizeTransferWriteOpsByDecomposition
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
rewriter, vectorType, smeTileType, transposed))) {
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
- auto smeWrite = rewriter.create<vector::TransferWriteOp>(
- loc, inputSMETiles[index], destTensorOrMemref,
+ auto smeWrite = vector::TransferWriteOp::create(
+ rewriter, loc, inputSMETiles[index], destTensorOrMemref,
getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
if (writeOp.hasPureTensorSemantics())
@@ -456,11 +457,11 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
VectorType::get(minTileSlices, rewriter.getI1Type(), true);
// Create loop over all tile slices.
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto upperBound = createVscaleMultiple(minTileSlices);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto storeLoop =
- rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
rewriter.setInsertionPointToStart(storeLoop.getBody());
// For each sub-tile of the multi-tile `vectorType`.
@@ -474,30 +475,31 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
// The current slice of `vectorType` we are processing.
auto sliceIndex =
- rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
+ arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex);
// Where in the destination memref the current slice will be stored.
- auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
- writeOp.getIndices()[0]);
- auto storeCol =
- rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+ auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex,
+ writeOp.getIndices()[0]);
+ auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol,
+ writeOp.getIndices()[1]);
// Extract the mask for the current slice.
Value sliceMask = nullptr;
if (mask) {
- sliceMask = rewriter.create<vector::ExtractOp>(
- loc, mask, OpFoldResult(sliceIndex));
+ sliceMask = vector::ExtractOp::create(rewriter, loc, mask,
+ OpFoldResult(sliceIndex));
if (sliceMaskType != sliceMask.getType())
- sliceMask = rewriter.create<vector::ScalableExtractOp>(
- loc, sliceMaskType, sliceMask, smeTile.col);
+ sliceMask = vector::ScalableExtractOp::create(
+ rewriter, loc, sliceMaskType, sliceMask, smeTile.col);
}
// Extract and store the current slice.
Value tile = inputSMETiles[index];
auto slice =
- rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
- rewriter.create<vector::TransferWriteOp>(
- loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol},
+ vector::ExtractOp::create(rewriter, loc, tile, tileSliceIndex);
+ vector::TransferWriteOp::create(
+ rewriter, loc, slice, writeOp.getBase(),
+ ValueRange{storeRow, storeCol},
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
sliceMask,
rewriter.getBoolArrayAttr(
@@ -567,14 +569,15 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
extractOp,
"constant vector.create_masks dims should be folded elsewhere");
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto extractionIndex = getValueOrCreateConstantIndexOp(
rewriter, loc, extractOp.getMixedPosition()[0]);
- auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
- loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
- frontMaskDim);
- auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
- loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
+ auto extractionInTrueRegion = arith::CmpIOp::create(
+ rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::slt,
+ extractionIndex, frontMaskDim);
+ auto newMaskFrontDim =
+ arith::SelectOp::create(rewriter, loc, extractionInTrueRegion,
+ createMaskOp.getOperand(1), zero);
rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
extractOp, extractedMaskType,
@@ -660,8 +663,8 @@ struct LiftIllegalVectorTransposeToMemory
illegalRead, "expected read to have identity permutation map");
auto loc = transposeOp.getLoc();
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Create a subview that matches the size of the illegal read vector type.
auto readType = illegalRead.getVectorType();
@@ -669,16 +672,16 @@ struct LiftIllegalVectorTransposeToMemory
llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
[&](auto dim) -> Value {
auto [size, isScalable] = dim;
- auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+ auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size);
if (!isScalable)
return dimSize;
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
- return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
+ return arith::MulIOp::create(rewriter, loc, vscale, dimSize);
});
SmallVector<Value> strides(readType.getRank(), Value(one));
- auto readSubview = rewriter.create<memref::SubViewOp>(
- loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes,
- strides);
+ auto readSubview =
+ memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(),
+ illegalRead.getIndices(), readSizes, strides);
// Apply the transpose to all values/attributes of the transfer_read:
// - The mask
@@ -686,14 +689,14 @@ struct LiftIllegalVectorTransposeToMemory
if (mask) {
// Note: The transpose for the mask should fold into the
// vector.create_mask/constant_mask op, which will then become legal.
- mask = rewriter.create<vector::TransposeOp>(loc, mask,
- transposeOp.getPermutation());
+ mask = vector::TransposeOp::create(rewriter, loc, mask,
+ transposeOp.getPermutation());
}
// - The source memref
mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
transposeOp.getPermutation(), getContext());
- auto transposedSubview = rewriter.create<memref::TransposeOp>(
- loc, readSubview, AffineMapAttr::get(transposeMap));
+ auto transposedSubview = memref::TransposeOp::create(
+ rewriter, loc, readSubview, AffineMapAttr::get(transposeMap));
ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
// - The `in_bounds` attribute
if (inBoundsAttr) {
@@ -706,8 +709,8 @@ struct LiftIllegalVectorTransposeToMemory
VectorType legalReadType = resultType.clone(readType.getElementType());
// Note: The indices are all zero as the subview is already offset.
SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
- auto legalRead = rewriter.create<vector::TransferReadOp>(
- loc, legalReadType, transposedSubview, readIndices,
+ auto legalRead = vector::TransferReadOp::create(
+ rewriter, loc, legalReadType, transposedSubview, readIndices,
illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
inBoundsAttr);
@@ -797,12 +800,12 @@ struct LowerIllegalTransposeStoreViaZA
AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
// Note: We need to use `get_tile` as there's no vector-level `undef`.
- Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+ Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType);
Value destTensorOrMemref = writeOp.getBase();
auto numSlicesPerTile =
std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
auto numSlices =
- rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+ arith::ConstantIndexOp::create(rewriter, loc, numSlicesPerTile);
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
// 1. _Deliberately_ drop a scalable dimension and insert a fixed number
@@ -811,47 +814,47 @@ struct LowerIllegalTransposeStoreViaZA
// rows of the tile after 1*vscale rows.
Value tile = undefTile;
for (int d = 0; d < numSlicesPerTile; ++d) {
- Value vector = rewriter.create<vector::ExtractOp>(
- loc, transposeOp.getVector(),
- rewriter.getIndexAttr(d + smeTile.row));
+ Value vector =
+ vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(),
+ rewriter.getIndexAttr(d + smeTile.row));
if (vector.getType() != smeSliceType) {
- vector = rewriter.create<vector::ScalableExtractOp>(
- loc, smeSliceType, vector, smeTile.col);
+ vector = vector::ScalableExtractOp::create(
+ rewriter, loc, smeSliceType, vector, smeTile.col);
}
- tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
+ tile = vector::InsertOp::create(rewriter, loc, vector, tile, d);
}
// 2. Transpose the tile position.
auto transposedRow = createVscaleMultiple(smeTile.col);
auto transposedCol =
- rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
+ arith::ConstantIndexOp::create(rewriter, loc, smeTile.row);
// 3. Compute mask for tile store.
Value maskRows;
Value maskCols;
if (auto mask = writeOp.getMask()) {
auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
- maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
- transposedRow);
- maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
- transposedCol);
- maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
+ maskRows = arith::SubIOp::create(
+ rewriter, loc, createMask.getOperand(0), transposedRow);
+ maskCols = arith::SubIOp::create(
+ rewriter, loc, createMask.getOperand(1), transposedCol);
+ maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices);
} else {
maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
maskCols = numSlices;
}
- auto subMask = rewriter.create<vector::CreateMaskOp>(
- loc, smeTileType.clone(rewriter.getI1Type()),
+ auto subMask = vector::CreateMaskOp::create(
+ rewriter, loc, smeTileType.clone(rewriter.getI1Type()),
ValueRange{maskRows, maskCols});
// 4. Emit a transposed tile write.
auto writeIndices = writeOp.getIndices();
Value destRow =
- rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
+ arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]);
Value destCol =
- rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
- auto smeWrite = rewriter.create<vector::TransferWriteOp>(
- loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
+ arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]);
+ auto smeWrite = vector::TransferWriteOp::create(
+ rewriter, loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
transposeMap, subMask, writeOp.getInBounds());
if (writeOp.hasPureTensorSemantics())
@@ -934,42 +937,42 @@ struct LowerColumnTransferReadToLoops
// Create a loop over all rows and load one element at a time.
auto loc = readOp.getLoc();
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto createVscaleMultiple =
vector::makeVscaleConstantBuilder(rewriter, loc);
auto upperBound = createVscaleMultiple(numRows);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- Value init = rewriter.create<arith::ConstantOp>(
- loc, newResType, DenseElementsAttr::get(newResType, 0.0f));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value init = arith::ConstantOp::create(
+ rewriter, loc, newResType, DenseElementsAttr::get(newResType, 0.0f));
scf::ForOp loadLoop;
{
OpBuilder::InsertionGuard g(rewriter);
- loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- ValueRange{init});
+ loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
+ ValueRange{init});
rewriter.setInsertionPointToStart(loadLoop.getBody());
auto tileSliceIndex = loadLoop.getInductionVar();
- auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex,
- readOp.getIndices()[0]);
+ auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex,
+ readOp.getIndices()[0]);
auto idx1 = readOp.getIndices()[1];
- Value scalar = rewriter.create<memref::LoadOp>(
- loc, readOp.getBase(), SmallVector<Value>({idx0, idx1}));
+ Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(),
+ SmallVector<Value>({idx0, idx1}));
- Operation *updateInit = rewriter.create<vector::InsertOp>(
- loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
+ Operation *updateInit = vector::InsertOp::create(
+ rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
- rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0));
+ scf::YieldOp::create(rewriter, loc, updateInit->getResult(0));
}
// The read operation has been "legalized", but since the original result
// type was a 2D vector, we need to cast before returning the result. This
    // ShapeCast should cancel out with some other ShapeCast (i.e. it's a
// no-op).
- auto sc = rewriter.create<vector::ShapeCastOp>(
- loc, readOp.getResult().getType(), loadLoop.getResult(0));
+ auto sc = vector::ShapeCastOp::create(
+ rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0));
rewriter.replaceOp(readOp, sc);
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 7b64e57..a7c6981 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -87,8 +87,8 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
VectorType sourceType = source.getType();
VectorType resultType = convertOp.getResult().getType();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultType, rewriter.getZeroAttr(resultType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
// We want to iterate over the input vector in steps of the trailing
// dimension. So this creates tile shape where all leading dimensions are 1,
@@ -100,15 +100,15 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
for (SmallVector<int64_t> index :
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
auto extractOrInsertPosition = ArrayRef(index).drop_back();
- auto sourceVector = rewriter.create<vector::ExtractOp>(
- loc, source, extractOrInsertPosition);
+ auto sourceVector = vector::ExtractOp::create(rewriter, loc, source,
+ extractOrInsertPosition);
VectorType convertedType =
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
.setDim(0, resultType.getShape().back());
auto convertedVector =
- rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
- result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
- extractOrInsertPosition);
+ IntrOp::create(rewriter, loc, TypeRange{convertedType}, sourceVector);
+ result = vector::InsertOp::create(rewriter, loc, convertedVector, result,
+ extractOrInsertPosition);
}
rewriter.replaceOp(convertOp, result);
@@ -135,12 +135,12 @@ struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
ConversionPatternRewriter &rewriter) const override {
auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
auto loc = pselOp.getLoc();
- auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
- adaptor.getP1());
- auto indexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), pselOp.getIndex());
- auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
- pselOp.getP2(), indexI32);
+ auto svboolP1 = ConvertToSvboolIntrOp::create(rewriter, loc, svboolType,
+ adaptor.getP1());
+ auto indexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), pselOp.getIndex());
+ auto pselIntr = PselIntrOp::create(rewriter, loc, svboolType, svboolP1,
+ pselOp.getP2(), indexI32);
rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
pselOp, adaptor.getP1().getType(), pselIntr);
return success();
@@ -174,7 +174,7 @@ struct CreateMaskOpLowering
"not SVE predicate-sized");
auto loc = createMaskOp.getLoc();
- auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
+ auto zero = LLVM::ZeroOp::create(rewriter, loc, rewriter.getI64Type());
rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
adaptor.getOperands()[0]);
return success();
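
This relies on WHILELT semantics: whilelt(0, n) sets predicate lane i exactly when 0 + i < n, matching vector.create_mask's "first n lanes active" contract for 1-D scalable masks (sketch):

    // vector.create_mask %n : vector<[16]xi1>
    //   ==>
    // %zero = llvm.mlir.zero : i64
    // arm_sve.intr.whilelt %zero, %n   // lane i active iff i < n
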
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index 3dbb93b..3a409ad 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -71,8 +71,8 @@ void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
TLegalizerCallback callback) {
replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
// Mark our `unrealized_conversion_casts` with a pass label.
- return rewriter.create<UnrealizedConversionCastOp>(
- op.getLoc(), TypeRange{op.getResult().getType()},
+ return UnrealizedConversionCastOp::create(
+ rewriter, op.getLoc(), TypeRange{op.getResult().getType()},
ValueRange{callback(newOp)},
NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag),
rewriter.getUnitAttr()));
@@ -239,8 +239,8 @@ struct LegalizeSVEMaskStoreConversion
auto legalMaskType = widenScalableMaskTypeToSvbool(
llvm::cast<VectorType>(valueToStore.getType()));
- auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
- loc, legalMaskType, valueToStore);
+ auto convertToSvbool = arm_sve::ConvertToSvboolOp::create(
+ rewriter, loc, legalMaskType, valueToStore);
// Replace this store with a conversion to a storable svbool mask [1],
// followed by a wider store.
replaceOpWithLegalizedOp(rewriter, storeOp,
@@ -290,8 +290,8 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
newLoadOp.setMemRef(*legalMemref);
newLoadOp.getResult().setType(legalMaskType);
- return rewriter.create<arm_sve::ConvertFromSvboolOp>(
- loc, loadedMask.getType(), newLoadOp);
+ return arm_sve::ConvertFromSvboolOp::create(
+ rewriter, loc, loadedMask.getType(), newLoadOp);
});
return success();
@@ -408,8 +408,8 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
reassoc.back().push_back(i);
if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
return failure();
- Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
- readOp.getLoc(), readOp.getBase(), reassoc);
+ Value collapsedMem = memref::CollapseShapeOp::create(
+ rewriter, readOp.getLoc(), readOp.getBase(), reassoc);
// Get a vector type with collapsed trailing dimensions.
SmallVector<int64_t> shape(origVT.getShape());
@@ -424,14 +424,14 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
// Create the new `transfer_read`.
- auto newReadOp = rewriter.create<vector::TransferReadOp>(
- readOp.getLoc(), collapsedVT, collapsedMem, indices,
+ auto newReadOp = vector::TransferReadOp::create(
+ rewriter, readOp.getLoc(), collapsedVT, collapsedMem, indices,
readOp.getPadding(),
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
// Cast back to the original vector type.
- auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
- origVT, newReadOp);
+ auto toOrigShape = vector::ShapeCastOp::create(rewriter, readOp.getLoc(),
+ origVT, newReadOp);
rewriter.replaceOp(readOp, toOrigShape);
return success();
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index ac1df38..35b0bd1 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v) {
// If the operand is not defined by an explicit extend operation of the
  // accepted operation type, allow for an implicit sign-extension.
- auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ auto extOp = v.getDefiningOp<Op>();
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto vTy = cast<VectorType>(v.getType());
@@ -214,13 +214,13 @@ Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
switch (mmlaOp) {
case MMLA::SignedInt:
- return rewriter.create<arm_sve::SmmlaOp>(loc, resTy, acc, lhs, rhs);
+ return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
case MMLA::UnsignedInt:
- return rewriter.create<arm_sve::UmmlaOp>(loc, resTy, acc, lhs, rhs);
+ return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
case MMLA::MixedInt:
- return rewriter.create<arm_sve::UsmmlaOp>(loc, resTy, acc, lhs, rhs);
+ return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
case MMLA::Bfloat:
- return rewriter.create<arm_sve::BfmmlaOp>(loc, resTy, acc, lhs, rhs);
+ return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
default:
llvm_unreachable("Uninitialized operation kind");
}
@@ -316,62 +316,63 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
for (int64_t i = 0; i < M; i += 2) {
// Extract two consecutive rows of the LHS tile.
auto r0 =
- rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+ vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
auto r1 =
- rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+ vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
// Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
SmallVector<int64_t> shuffleIdx(2 * K);
std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
- auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+ auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
// Turn it into a scalable vector.
- auto s = rewriter.create<vector::ScalableInsertOp>(
- loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+ auto s = vector::ScalableInsertOp::create(
+ rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0);
// Replicate the sub-tile VSCALE times to fill the entire vector.
- auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0);
lhsTile.push_back(r);
}
// "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
- auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
- flatRhsTileType, this->rhs);
+ auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(),
+ flatRhsTileType, this->rhs);
// Extract the RHS sub-tiles with logical shape <Kx[2]>.
SmallVector<Value> rhsTile;
for (int64_t j = 0; j < N; j += 2)
- rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
- loc, flatRhsType, rhs, j * K));
+ rhsTile.push_back(vector::ScalableExtractOp::create(
+ rewriter, loc, flatRhsType, rhs, j * K));
// Extract and pack the ACC sub-tiles.
SmallVector<Value> accTile;
for (int64_t i = 0; i < M; i += 2) {
// Extract two consecutive rows of the accumulator tile.
- auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
- ArrayRef<int64_t>{i});
- auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
- ArrayRef<int64_t>{i + 1});
+ auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
Value accTileVec;
if (swapOperands) {
// We are performing the operation with swapped LHS and RHS, so we need to
// transpose each individual 2x2 tile of the accumulator and (later) the
// final result.
- accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+ accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1);
} else {
// Bitcast accumulator rows to double-width integer elements, so
// subsequent interleave/deinterleave work on pairs of elements.
- auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
- auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+ auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0);
+ auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1);
// Interleave the rows, effectively flattening each 2x2 tile into 4
// consecutive elements.
- auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
+ auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64);
// Bitcast back to original element type.
- accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
+ accTileVec =
+ vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
}
// Extract ACC sub-tiles.
for (int64_t j = 0; j < N; j += 2)
- accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
- loc, flatAccType, accTileVec, j * 2));
+ accTile.push_back(vector::ScalableExtractOp::create(
+ rewriter, loc, flatAccType, accTileVec, j * 2));
}
// Emit sub-tile matrix multiplications.
@@ -384,13 +385,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
}
// Unpack the OUT sub-tiles and insert into the result.
- Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+ Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
for (int64_t i = 0; i < M / 2; ++i) {
// Collect a number of sub-tiles in a row.
- Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+ Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
for (int64_t j = 0; j < N / 2; ++j)
- row = rewriter.create<vector::ScalableInsertOp>(
- loc, outTile[i * N / 2 + j], row, j * 4);
+ row = vector::ScalableInsertOp::create(
+ rewriter, loc, outTile[i * N / 2 + j], row, j * 4);
// Unpack the row to obtain two rows of the output. If we have the out
// sub-tiles transposed, we obtain two consecutive output rows by
@@ -398,22 +399,22 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Otherwise, the interleave is by pairs.
Value out0, out1;
if (swapOperands) {
- auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+ auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row);
out0 = tmp.getRes1();
out1 = tmp.getRes2();
} else {
// Deinterleave by pairs.
- auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
- auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+ auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row);
+ auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64);
// Bitcast back into original element type and insert into the result.
- out0 =
- rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1());
- out1 =
- rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2());
+ out0 = vector::BitCastOp::create(rewriter, loc, accRowTy,
+ deintr64.getRes1());
+ out1 = vector::BitCastOp::create(rewriter, loc, accRowTy,
+ deintr64.getRes2());
}
- result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
- result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+ result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2);
+ result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1);
}
return result;
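Most hunks in this patch apply the same mechanical rewrite: the member-function form rewriter.create<OpTy>(loc, ...) becomes the static OpTy::create(rewriter, loc, ...), with the builder hoisted into the first argument and everything else unchanged. A minimal before/after sketch (op and operand names are illustrative):

// Before:
//   Value sum = rewriter.create<arith::AddIOp>(loc, lhs, rhs);
// After, as used throughout this patch:
Value sum = arith::AddIOp::create(rewriter, loc, lhs, rhs);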
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 08a57db..dc7b07d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -97,7 +97,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
// expected result is empty. Otherwise, leave this to the caller
// because we don't know which values to return from the execute op.
if (resultTypes.empty() && !bodyBuilder) {
- builder.create<async::YieldOp>(result.location, ValueRange());
+ async::YieldOp::create(builder, result.location, ValueRange());
} else if (bodyBuilder) {
bodyBuilder(builder, result.location, bodyBlock->getArguments());
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index bf6bfe2a..96283cd 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -190,8 +190,8 @@ static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
assert(!tripCounts.empty() && "tripCounts must be not empty");
for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
- coords[i] = b.create<arith::RemSIOp>(index, tripCounts[i]);
- index = b.create<arith::DivSIOp>(index, tripCounts[i]);
+ coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]);
+ index = arith::DivSIOp::create(b, index, tripCounts[i]);
}
return coords;
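The loop above is standard delinearization: arith.remsi peels off the coordinate of the innermost remaining dimension, and arith.divsi discards that dimension from the linear index. A self-contained sketch of the same arithmetic on plain integers (function name is illustrative):

#include <cstdint>
#include <vector>

// Mirror of the rem/div loop above, innermost dimension handled first.
std::vector<int64_t> delinearizeIndex(int64_t index,
                                      const std::vector<int64_t> &tripCounts) {
  std::vector<int64_t> coords(tripCounts.size());
  for (int64_t i = static_cast<int64_t>(tripCounts.size()) - 1; i >= 0; --i) {
    coords[i] = index % tripCounts[i]; // arith.remsi
    index /= tripCounts[i];            // arith.divsi
  }
  return coords;
}
// e.g. index 7 with tripCounts {2, 4} yields coords {1, 3}.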
@@ -275,15 +275,15 @@ static ParallelComputeFunction createParallelComputeFunction(
BlockArgument blockSize = args.blockSize();
// Constants used below.
- Value c0 = b.create<arith::ConstantIndexOp>(0);
- Value c1 = b.create<arith::ConstantIndexOp>(1);
+ Value c0 = arith::ConstantIndexOp::create(b, 0);
+ Value c1 = arith::ConstantIndexOp::create(b, 1);
// Materialize known constants as constant operations in the function body.
auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
return llvm::to_vector(
llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
if (IntegerAttr attr = std::get<1>(tuple))
- return b.create<arith::ConstantOp>(attr);
+ return arith::ConstantOp::create(b, attr);
return std::get<0>(tuple);
}));
};
@@ -302,17 +302,17 @@ static ParallelComputeFunction createParallelComputeFunction(
// one-dimensional iteration space.
Value tripCount = tripCounts[0];
for (unsigned i = 1; i < tripCounts.size(); ++i)
- tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
+ tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);
// Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
// blockFirstIndex = blockIndex * blockSize
- Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
+ Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize);
// The last one-dimensional index in the block defined by the `blockIndex`:
// blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
- Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
- Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
- Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
+ Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize);
+ Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount);
+ Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1);
// Convert one-dimensional indices to multi-dimensional coordinates.
auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
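As a worked example of these bounds (numbers are illustrative): with tripCount = 10, blockSize = 4, and blockIndex = 2, blockFirstIndex = 2 * 4 = 8 and blockLastIndex = min(8 + 4, 10) - 1 = 9, so the final block covers only the two leftover iterations.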
@@ -325,7 +325,7 @@ static ParallelComputeFunction createParallelComputeFunction(
// dimension when inner compute dimension contains multiple blocks.
SmallVector<Value> blockEndCoord(op.getNumLoops());
for (size_t i = 0; i < blockLastCoord.size(); ++i)
- blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1);
+ blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1);
// Construct a loop nest out of scf.for operations that will iterate over
// all coordinates in [blockFirstCoord, blockLastCoord] range.
@@ -368,21 +368,22 @@ static ParallelComputeFunction createParallelComputeFunction(
ImplicitLocOpBuilder b(loc, nestedBuilder);
// Compute induction variable for `loopIdx`.
- computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>(
- lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx]));
+ computeBlockInductionVars[loopIdx] =
+ arith::AddIOp::create(b, lowerBounds[loopIdx],
+ arith::MulIOp::create(b, iv, steps[loopIdx]));
// Check if we are inside first or last iteration of the loop.
- isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
- isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
+ isBlockFirstCoord[loopIdx] = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
+ isBlockLastCoord[loopIdx] = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
// Check if the previous loop is in its first or last iteration.
if (loopIdx > 0) {
- isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>(
- isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
- isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>(
- isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
+ isBlockFirstCoord[loopIdx] = arith::AndIOp::create(
+ b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
+ isBlockLastCoord[loopIdx] = arith::AndIOp::create(
+ b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
}
// Keep building loop nest.
@@ -390,24 +391,24 @@ static ParallelComputeFunction createParallelComputeFunction(
if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
// For block aligned loops we always iterate starting from 0 up to
// the loop trip counts.
- b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
- workLoopBuilder(loopIdx + 1));
+ scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1, ValueRange(),
+ workLoopBuilder(loopIdx + 1));
} else {
// Select nested loop lower/upper bounds depending on our position in
// the multi-dimensional iteration space.
- auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
- blockFirstCoord[loopIdx + 1], c0);
+ auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx],
+ blockFirstCoord[loopIdx + 1], c0);
- auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
- blockEndCoord[loopIdx + 1],
- tripCounts[loopIdx + 1]);
+ auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx],
+ blockEndCoord[loopIdx + 1],
+ tripCounts[loopIdx + 1]);
- b.create<scf::ForOp>(lb, ub, c1, ValueRange(),
- workLoopBuilder(loopIdx + 1));
+ scf::ForOp::create(b, lb, ub, c1, ValueRange(),
+ workLoopBuilder(loopIdx + 1));
}
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
return;
}
@@ -418,13 +419,13 @@ static ParallelComputeFunction createParallelComputeFunction(
for (auto &bodyOp : op.getRegion().front().without_terminator())
b.clone(bodyOp, mapping);
- b.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(b, loc);
};
};
- b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
- workLoopBuilder(0));
- b.create<func::ReturnOp>(ValueRange());
+ scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
+ workLoopBuilder(0));
+ func::ReturnOp::create(b, ValueRange());
return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}
@@ -484,8 +485,8 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
b.setInsertionPointToEnd(block);
Type indexTy = b.getIndexType();
- Value c1 = b.create<arith::ConstantIndexOp>(1);
- Value c2 = b.create<arith::ConstantIndexOp>(2);
+ Value c1 = arith::ConstantIndexOp::create(b, 1);
+ Value c2 = arith::ConstantIndexOp::create(b, 2);
// Get the async group that will track async dispatch completion.
Value group = block->getArgument(0);
@@ -500,7 +501,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
SmallVector<Location> locations = {loc, loc};
// Create a recursive dispatch loop.
- scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
+ scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands);
Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);
@@ -510,10 +511,10 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
b.setInsertionPointToEnd(before);
Value start = before->getArgument(0);
Value end = before->getArgument(1);
- Value distance = b.create<arith::SubIOp>(end, start);
+ Value distance = arith::SubIOp::create(b, end, start);
Value dispatch =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
- b.create<scf::ConditionOp>(dispatch, before->getArguments());
+ arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1);
+ scf::ConditionOp::create(b, dispatch, before->getArguments());
}
// Set up the async dispatch loop body: recursively call the dispatch function
@@ -522,9 +523,9 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
b.setInsertionPointToEnd(after);
Value start = after->getArgument(0);
Value end = after->getArgument(1);
- Value distance = b.create<arith::SubIOp>(end, start);
- Value halfDistance = b.create<arith::DivSIOp>(distance, c2);
- Value midIndex = b.create<arith::AddIOp>(start, halfDistance);
+ Value distance = arith::SubIOp::create(b, end, start);
+ Value halfDistance = arith::DivSIOp::create(b, distance, c2);
+ Value midIndex = arith::AddIOp::create(b, start, halfDistance);
// Call parallel compute function inside the async.execute region.
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
@@ -535,16 +536,16 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
operands[1] = midIndex;
operands[2] = end;
- executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(),
- func.getResultTypes(), operands);
- executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
+ func::CallOp::create(executeBuilder, executeLoc, func.getSymName(),
+ func.getResultTypes(), operands);
+ async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
};
// Create async.execute operation to dispatch half of the block range.
- auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
- executeBodyBuilder);
- b.create<AddToGroupOp>(indexTy, execute.getToken(), group);
- b.create<scf::YieldOp>(ValueRange({start, midIndex}));
+ auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
+ executeBodyBuilder);
+ AddToGroupOp::create(b, indexTy, execute.getToken(), group);
+ scf::YieldOp::create(b, ValueRange({start, midIndex}));
}
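The scf.while loop assembled above performs recursive binary work splitting: while more than one block remains, the upper half of the range is handed off asynchronously and the loop continues with the lower half. A plain-C++ sketch of that control flow (dispatchAsync and compute are illustrative stand-ins for the async.execute region and the compute-function call):

#include <cstdint>

void dispatchAsync(int64_t begin, int64_t end); // stand-in: async.execute + add_to_group
void compute(int64_t blockStart);               // stand-in: call the compute function

void dispatchBlocks(int64_t start, int64_t end) {
  while (end - start > 1) {          // scf.condition: distance > 1
    int64_t mid = start + (end - start) / 2;
    dispatchAsync(mid, end);         // hand off the upper half
    end = mid;                       // scf.yield {start, mid}
  }
  compute(start);                    // remaining block runs in the caller thread
}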
// After dispatching async operations to process the tail of the block range
@@ -556,10 +557,9 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
SmallVector<Value> computeFuncOperands = {blockStart};
computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
- b.create<func::CallOp>(computeFunc.func.getSymName(),
- computeFunc.func.getResultTypes(),
- computeFuncOperands);
- b.create<func::ReturnOp>(ValueRange());
+ func::CallOp::create(b, computeFunc.func.getSymName(),
+ computeFunc.func.getResultTypes(), computeFuncOperands);
+ func::ReturnOp::create(b, ValueRange());
return func;
}
@@ -577,8 +577,8 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
func::FuncOp asyncDispatchFunction =
createAsyncDispatchFunction(parallelComputeFunction, rewriter);
- Value c0 = b.create<arith::ConstantIndexOp>(0);
- Value c1 = b.create<arith::ConstantIndexOp>(1);
+ Value c0 = arith::ConstantIndexOp::create(b, 0);
+ Value c1 = arith::ConstantIndexOp::create(b, 1);
// Appends operands shared by async dispatch and parallel compute functions to
// the given operands vector.
@@ -594,7 +594,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
// completely. If this is known statically, then canonicalization will
// erase async group operations.
Value isSingleBlock =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1);
auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder b(loc, nestedBuilder);
@@ -603,10 +603,10 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
SmallVector<Value> operands = {c0, blockSize};
appendBlockComputeOperands(operands);
- b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
- parallelComputeFunction.func.getResultTypes(),
- operands);
- b.create<scf::YieldOp>();
+ func::CallOp::create(b, parallelComputeFunction.func.getSymName(),
+ parallelComputeFunction.func.getResultTypes(),
+ operands);
+ scf::YieldOp::create(b);
};
auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
@@ -615,24 +615,24 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute functions. The first block will be
// executed synchronously in the caller thread.
- Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
- Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+ Value groupSize = arith::SubIOp::create(b, blockCount, c1);
+ Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);
// Launch async dispatch function for [0, blockCount) range.
SmallVector<Value> operands = {group, c0, blockCount, blockSize};
appendBlockComputeOperands(operands);
- b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
- asyncDispatchFunction.getResultTypes(), operands);
+ func::CallOp::create(b, asyncDispatchFunction.getSymName(),
+ asyncDispatchFunction.getResultTypes(), operands);
// Wait for the completion of all parallel compute operations.
- b.create<AwaitAllOp>(group);
+ AwaitAllOp::create(b, group);
- b.create<scf::YieldOp>();
+ scf::YieldOp::create(b);
};
// Dispatch either the single-block compute function or launch async dispatch.
- b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
+ scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch);
}
// Dispatch parallel compute functions by submitting all async compute tasks
@@ -646,14 +646,14 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
func::FuncOp compute = parallelComputeFunction.func;
- Value c0 = b.create<arith::ConstantIndexOp>(0);
- Value c1 = b.create<arith::ConstantIndexOp>(1);
+ Value c0 = arith::ConstantIndexOp::create(b, 0);
+ Value c1 = arith::ConstantIndexOp::create(b, 1);
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute functions. The first block will be
// executed synchronously in the caller thread.
- Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
- Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+ Value groupSize = arith::SubIOp::create(b, blockCount, c1);
+ Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize);
// Call parallel compute function for all blocks.
using LoopBodyBuilder =
@@ -680,28 +680,27 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
// Call parallel compute function inside the async.execute region.
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
Location executeLoc, ValueRange executeArgs) {
- executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
- compute.getResultTypes(),
- computeFuncOperands(iv));
- executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
+ func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(),
+ compute.getResultTypes(), computeFuncOperands(iv));
+ async::YieldOp::create(executeBuilder, executeLoc, ValueRange());
};
// Create an async.execute operation to launch the parallel compute function.
- auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
- executeBodyBuilder);
- b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group);
- b.create<scf::YieldOp>();
+ auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(),
+ executeBodyBuilder);
+ AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group);
+ scf::YieldOp::create(b);
};
// Iterate over all compute blocks and launch parallel compute operations.
- b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);
+ scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder);
// Call parallel compute function for the first block in the caller thread.
- b.create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
- computeFuncOperands(c0));
+ func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(),
+ computeFuncOperands(c0));
// Wait for the completion of all async compute operations.
- b.create<AwaitAllOp>(group);
+ AwaitAllOp::create(b, group);
}
LogicalResult
@@ -737,17 +736,17 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// for the scf.parallel operation.
Value tripCount = tripCounts[0];
for (size_t i = 1; i < tripCounts.size(); ++i)
- tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
+ tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]);
// Short-circuit no-op parallel loops (zero iterations) that can arise from
// memrefs with dynamic dimension(s) equal to zero.
- Value c0 = b.create<arith::ConstantIndexOp>(0);
+ Value c0 = arith::ConstantIndexOp::create(b, 0);
Value isZeroIterations =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0);
// Do absolutely nothing if the trip count is zero.
auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
- nestedBuilder.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(nestedBuilder, loc);
};
// Compute the parallel block size and dispatch concurrent tasks computing
@@ -797,9 +796,9 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value numWorkerThreadsVal;
if (numWorkerThreads >= 0)
- numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
+ numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads);
else
- numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
+ numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b);
// With a large number of threads the value of creating many compute blocks
// is reduced because the problem typically becomes memory bound. For this
@@ -818,38 +817,38 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
{4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
const float initialOvershardingFactor = 8.0f;
- Value scalingFactor = b.create<arith::ConstantFloatOp>(
- b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
+ Value scalingFactor = arith::ConstantFloatOp::create(
+ b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
for (const std::pair<int, float> &p : overshardingBrackets) {
- Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
- Value inBracket = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
- Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
- b.getF32Type(), llvm::APFloat(p.second));
- scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
- scalingFactor);
+ Value bracketBegin = arith::ConstantIndexOp::create(b, p.first);
+ Value inBracket = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
+ Value bracketScalingFactor = arith::ConstantFloatOp::create(
+ b, b.getF32Type(), llvm::APFloat(p.second));
+ scalingFactor = arith::SelectOp::create(
+ b, inBracket, bracketScalingFactor, scalingFactor);
}
Value numWorkersIndex =
- b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
+ arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal);
Value numWorkersFloat =
- b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
+ arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex);
Value scaledNumWorkers =
- b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
+ arith::MulFOp::create(b, scalingFactor, numWorkersFloat);
Value scaledNumInt =
- b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
+ arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers);
Value scaledWorkers =
- b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
+ arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt);
- Value maxComputeBlocks = b.create<arith::MaxSIOp>(
- b.create<arith::ConstantIndexOp>(1), scaledWorkers);
+ Value maxComputeBlocks = arith::MaxSIOp::create(
+ b, arith::ConstantIndexOp::create(b, 1), scaledWorkers);
// Compute parallel block size from the parallel problem size:
// blockSize = min(tripCount,
// max(ceil_div(tripCount, maxComputeBlocks),
// minTaskSize))
- Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
- Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
- Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
+ Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks);
+ Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize);
+ Value blockSize = arith::MinSIOp::create(b, tripCount, bs1);
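A self-contained sketch of this min/max clamp on plain integers (function and parameter names are illustrative):

#include <algorithm>
#include <cstdint>

int64_t computeBlockSize(int64_t tripCount, int64_t maxComputeBlocks,
                         int64_t minTaskSize) {
  int64_t bs0 = (tripCount + maxComputeBlocks - 1) / maxComputeBlocks; // ceil_div
  int64_t bs1 = std::max(bs0, minTaskSize);
  return std::min(tripCount, bs1);
}
// e.g. tripCount = 100, maxComputeBlocks = 8, minTaskSize = 20 gives
// blockSize = min(100, max(13, 20)) = 20, i.e. five parallel blocks.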
// Dispatch parallel compute function using async recursive work splitting,
// or by submitting compute tasks sequentially from the caller thread.
@@ -859,7 +858,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// the parallel operation body for a subset of iteration space.
// Compute the number of parallel compute blocks.
- Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
+ Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize);
// Dispatch parallel compute function without hints to unroll inner loops.
auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
@@ -868,7 +867,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
ImplicitLocOpBuilder b(loc, nestedBuilder);
doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
- b.create<scf::YieldOp>();
+ scf::YieldOp::create(b);
};
// Dispatch parallel compute function with hints for unrolling inner loops.
@@ -879,34 +878,34 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
ImplicitLocOpBuilder b(loc, nestedBuilder);
// Align the block size to be a multiple of the statically known
// number of iterations in the inner loops.
- Value numIters = b.create<arith::ConstantIndexOp>(
- numIterations[op.getNumLoops() - numUnrollableLoops]);
- Value alignedBlockSize = b.create<arith::MulIOp>(
- b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+ Value numIters = arith::ConstantIndexOp::create(
+ b, numIterations[op.getNumLoops() - numUnrollableLoops]);
+ Value alignedBlockSize = arith::MulIOp::create(
+ b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters);
doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
tripCounts);
- b.create<scf::YieldOp>();
+ scf::YieldOp::create(b);
};
// Dispatch to block aligned compute function only if the computed block
// size is larger than the number of iterations in the unrollable inner
// loops, because otherwise it can reduce the available parallelism.
if (numUnrollableLoops > 0) {
- Value numIters = b.create<arith::ConstantIndexOp>(
- numIterations[op.getNumLoops() - numUnrollableLoops]);
- Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::sge, blockSize, numIters);
-
- b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
- dispatchDefault);
- b.create<scf::YieldOp>();
+ Value numIters = arith::ConstantIndexOp::create(
+ b, numIterations[op.getNumLoops() - numUnrollableLoops]);
+ Value useBlockAlignedComputeFn = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::sge, blockSize, numIters);
+
+ scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned,
+ dispatchDefault);
+ scf::YieldOp::create(b);
} else {
dispatchDefault(b, loc);
}
};
// Replace the `scf.parallel` operation with the parallel compute function.
- b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
+ scf::IfOp::create(b, isZeroIterations, noOp, dispatch);
// Parallel operation was replaced with a block iteration loop.
rewriter.eraseOp(op);
@@ -921,7 +920,7 @@ void AsyncParallelForPass::runOnOperation() {
populateAsyncParallelForPatterns(
patterns, asyncDispatch, numWorkerThreads,
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
- return builder.create<arith::ConstantIndexOp>(minTaskSize);
+ return arith::ConstantIndexOp::create(builder, minTaskSize);
});
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 0da9b3a..ddc64ea 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -48,7 +48,7 @@ static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
else
b.setInsertionPointToStart(value.getParentBlock());
- b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
+ RuntimeDropRefOp::create(b, value.getLoc(), value, b.getI64IntegerAttr(1));
return success();
}
@@ -309,7 +309,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
// Add a drop_ref immediately after the last user.
builder.setInsertionPointAfter(lastUser);
- builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
+ RuntimeDropRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1));
}
return success();
@@ -327,7 +327,7 @@ AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
// Add a reference before the function call to pass the value at `+1`
// reference to the function entry block.
builder.setInsertionPoint(user);
- builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
+ RuntimeAddRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1));
}
return success();
@@ -411,12 +411,12 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
refCountingBlock = &successor->getParent()->emplaceBlock();
refCountingBlock->moveBefore(successor);
OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
- builder.create<cf::BranchOp>(value.getLoc(), successor);
+ cf::BranchOp::create(builder, value.getLoc(), successor);
}
OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
- builder.create<RuntimeDropRefOp>(value.getLoc(), value,
- builder.getI64IntegerAttr(1));
+ RuntimeDropRefOp::create(builder, value.getLoc(), value,
+ builder.getI64IntegerAttr(1));
// No need to update the terminator operation.
if (successor == refCountingBlock)
@@ -507,13 +507,13 @@ AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
// Create `add_ref` operation before the operand owner.
if (cnt > 0) {
b.setInsertionPoint(operand.getOwner());
- b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
+ RuntimeAddRefOp::create(b, loc, value, b.getI64IntegerAttr(cnt));
}
// Create `drop_ref` operation after the operand owner.
if (cnt < 0) {
b.setInsertionPointAfter(operand.getOwner());
- b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
+ RuntimeDropRefOp::create(b, loc, value, b.getI64IntegerAttr(-cnt));
}
}
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 44a3837..112d69c 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -186,22 +186,22 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
std::optional<Value> retToken;
if (isStateful)
- retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
+ retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
llvm::SmallVector<Value, 4> retValues;
ArrayRef<Type> resValueTypes =
isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
for (auto resType : resValueTypes)
retValues.emplace_back(
- builder.create<RuntimeCreateOp>(resType).getResult());
+ RuntimeCreateOp::create(builder, resType).getResult());
// ------------------------------------------------------------------------ //
// Initialize coroutine: get coroutine id and coroutine handle.
// ------------------------------------------------------------------------ //
- auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
+ auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
auto coroHdlOp =
- builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
- builder.create<cf::BranchOp>(originalEntryBlock);
+ CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
+ cf::BranchOp::create(builder, originalEntryBlock);
Block *cleanupBlock = func.addBlock();
Block *cleanupBlockForDestroy = func.addBlock();
@@ -212,10 +212,10 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// ------------------------------------------------------------------------ //
auto buildCleanupBlock = [&](Block *cb) {
builder.setInsertionPointToStart(cb);
- builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
+ CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
// Branch into the suspend block.
- builder.create<cf::BranchOp>(suspendBlock);
+ cf::BranchOp::create(builder, suspendBlock);
};
buildCleanupBlock(cleanupBlock);
buildCleanupBlock(cleanupBlockForDestroy);
@@ -227,7 +227,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
builder.setInsertionPointToStart(suspendBlock);
// Mark the end of a coroutine: async.coro.end
- builder.create<CoroEndOp>(coroHdlOp.getHandle());
+ CoroEndOp::create(builder, coroHdlOp.getHandle());
// Return created optional `async.token` and `async.values` from the suspend
// block. This will be the return value of a coroutine ramp function.
@@ -235,7 +235,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
if (retToken)
ret.push_back(*retToken);
llvm::append_range(ret, retValues);
- builder.create<func::ReturnOp>(ret);
+ func::ReturnOp::create(builder, ret);
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
@@ -272,13 +272,13 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) {
// Coroutine set_error block: set error on token and all returned values.
if (coro.asyncToken)
- builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
+ RuntimeSetErrorOp::create(builder, *coro.asyncToken);
for (Value retValue : coro.returnValues)
- builder.create<RuntimeSetErrorOp>(retValue);
+ RuntimeSetErrorOp::create(builder, retValue);
// Branch into the cleanup block.
- builder.create<cf::BranchOp>(coro.cleanup);
+ cf::BranchOp::create(builder, coro.cleanup);
return *coro.setError;
}
@@ -333,13 +333,13 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Await on all dependencies before starting to execute the body region.
for (size_t i = 0; i < numDependencies; ++i)
- builder.create<AwaitOp>(func.getArgument(i));
+ AwaitOp::create(builder, func.getArgument(i));
// Await on all async value operands and unwrap the payload.
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
- unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
+ unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
}
// Map from function inputs defined above the execute op to the function
@@ -366,15 +366,15 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Save the coroutine state: async.coro.save
auto coroSaveOp =
- builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
+ CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
// Pass coroutine to the runtime to be resumed on a runtime managed
// thread.
- builder.create<RuntimeResumeOp>(coro.coroHandle);
+ RuntimeResumeOp::create(builder, coro.coroHandle);
// Add async.coro.suspend as a suspended block terminator.
- builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
- branch.getDest(), coro.cleanupForDestroy);
+ CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
+ branch.getDest(), coro.cleanupForDestroy);
branch.erase();
}
@@ -382,8 +382,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Replace the original `async.execute` with a call to outlined function.
{
ImplicitLocOpBuilder callBuilder(loc, execute);
- auto callOutlinedFunc = callBuilder.create<func::CallOp>(
- func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
+ auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(),
+ execute.getResultTypes(),
+ functionInputs.getArrayRef());
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
execute.erase();
}
@@ -451,7 +452,7 @@ public:
Location loc = op->getLoc();
auto newFuncOp =
- rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
+ func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
SymbolTable::setSymbolVisibility(newFuncOp,
SymbolTable::getSymbolVisibility(op));
@@ -521,16 +522,16 @@ public:
for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
Value returnValue = std::get<0>(tuple);
Value asyncValue = std::get<1>(tuple);
- rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
- rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
+ RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
+ RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
}
if (coro.asyncToken)
// Switch the coroutine completion token to available state.
- rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
+ RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
rewriter.eraseOp(op);
- rewriter.create<cf::BranchOp>(loc, coro.cleanup);
+ cf::BranchOp::create(rewriter, loc, coro.cleanup);
return success();
}
@@ -581,16 +582,17 @@ public:
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
ImplicitLocOpBuilder builder(loc, rewriter);
- builder.create<RuntimeAwaitOp>(loc, operand);
+ RuntimeAwaitOp::create(builder, loc, operand);
// Assert that the awaited operand is not in the error state.
- Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
- Value notError = builder.create<arith::XOrIOp>(
- isError, builder.create<arith::ConstantOp>(
- loc, i1, builder.getIntegerAttr(i1, 1)));
-
- builder.create<cf::AssertOp>(notError,
- "Awaited async operand is in error state");
+ Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
+ Value notError = arith::XOrIOp::create(
+ builder, isError,
+ arith::ConstantOp::create(builder, loc, i1,
+ builder.getIntegerAttr(i1, 1)));
+
+ cf::AssertOp::create(builder, notError,
+ "Awaited async operand is in error state");
}
// Inside the coroutine we convert the await operation into a coroutine suspension
@@ -605,28 +607,28 @@ public:
// Save the coroutine state and resume on a runtime managed thread when
// the operand becomes available.
auto coroSaveOp =
- builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
- builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
+ CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
+ RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
// Split the entry block before the await operation.
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
- builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
- coro.cleanupForDestroy);
+ CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
+ resume, coro.cleanupForDestroy);
// Split the resume block into error checking and continuation.
Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
// Check if the awaited value is in the error state.
builder.setInsertionPointToStart(resume);
- auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
- builder.create<cf::CondBranchOp>(isError,
- /*trueDest=*/setupSetErrorBlock(coro),
- /*trueArgs=*/ArrayRef<Value>(),
- /*falseDest=*/continuation,
- /*falseArgs=*/ArrayRef<Value>());
+ auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
+ cf::CondBranchOp::create(builder, isError,
+ /*trueDest=*/setupSetErrorBlock(coro),
+ /*trueArgs=*/ArrayRef<Value>(),
+ /*falseDest=*/continuation,
+ /*falseArgs=*/ArrayRef<Value>());
// Make sure that replacement value will be constructed in the
// continuation block.
@@ -672,7 +674,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
// Load from the async value storage.
auto valueType = cast<ValueType>(operand.getType()).getValueType();
- return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
+ return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
}
};
@@ -713,15 +715,15 @@ public:
for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
Value yieldValue = std::get<0>(tuple);
Value asyncValue = std::get<1>(tuple);
- rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
- rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
+ RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
+ RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
}
if (coro.asyncToken)
// Switch the coroutine completion token to available state.
- rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
+ RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
- rewriter.create<cf::BranchOp>(loc, coro.cleanup);
+ cf::BranchOp::create(rewriter, loc, coro.cleanup);
rewriter.eraseOp(op);
return success();
@@ -755,11 +757,11 @@ public:
Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
rewriter.setInsertionPointToEnd(cont->getPrevNode());
- rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
- /*trueDest=*/cont,
- /*trueArgs=*/ArrayRef<Value>(),
- /*falseDest=*/setupSetErrorBlock(coro),
- /*falseArgs=*/ArrayRef<Value>());
+ cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
+ /*trueDest=*/cont,
+ /*trueArgs=*/ArrayRef<Value>(),
+ /*falseDest=*/setupSetErrorBlock(coro),
+ /*falseArgs=*/ArrayRef<Value>());
rewriter.eraseOp(op);
return success();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index 2bf326a..4dfba74 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -35,7 +35,7 @@ using namespace bufferization;
//===----------------------------------------------------------------------===//
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
- return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
+ return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value));
}
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
@@ -150,7 +150,7 @@ DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
// ownerships more intelligently to not end up with an 'Unknown' ownership in
// the first place.
auto cloneOp =
- builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
+ bufferization::CloneOp::create(builder, memref.getLoc(), memref);
Value condition = buildBoolValue(builder, memref.getLoc(), true);
Value newMemref = cloneOp.getResult();
updateOwnership(newMemref, condition);
@@ -196,8 +196,8 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
// that we can call extract_strided_metadata on it.
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
- memref = builder.create<memref::ReinterpretCastOp>(
- loc, memref,
+ memref = memref::ReinterpretCastOp::create(
+ builder, loc, memref,
/*offset=*/builder.getIndexAttr(0),
/*sizes=*/ArrayRef<OpFoldResult>{},
/*strides=*/ArrayRef<OpFoldResult>{});
@@ -207,7 +207,7 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// alloc operation has to be passed to the dealloc operation. Passing
// subviews, etc. to a dealloc operation is not allowed.
memrefs.push_back(
- builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
+ memref::ExtractStridedMetadataOp::create(builder, loc, memref)
.getResult(0));
conditions.push_back(ownership.getIndicator());
}
@@ -296,8 +296,8 @@ FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike(
if (memrefs.empty() && toRetain.empty())
return op;
- auto deallocOp = builder.create<bufferization::DeallocOp>(
- op->getLoc(), memrefs, conditions, toRetain);
+ auto deallocOp = bufferization::DeallocOp::create(
+ builder, op->getLoc(), memrefs, conditions, toRetain);
// We want to replace the current ownership of the retained values with the
// result values of the dealloc operation as they are always unique.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8f17a82f..f7b0b87 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,7 +18,6 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
-#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
// BufferizableOpInterface
@@ -35,8 +34,6 @@ namespace bufferization {
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
#define DEBUG_TYPE "bufferizable-op-interface"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
using namespace mlir;
using namespace bufferization;
@@ -170,8 +167,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
} else if (llvm::isa<MemRefType>(shapedValue.getType())) {
- tensor = b.create<ToTensorOp>(
- loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()),
+ tensor = ToTensorOp::create(
+ b, loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()),
shapedValue);
} else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
@@ -209,8 +206,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}
// Create AllocTensorOp.
- auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
- copy ? tensor : Value());
+ auto allocTensorOp = AllocTensorOp::create(b, loc, tensorType, dynamicSizes,
+ copy ? tensor : Value());
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
@@ -691,8 +688,8 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
if (failed(bufferType))
return failure();
ensureToBufferOpIsValid(value, *bufferType);
- return rewriter
- .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
+ return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
+ *bufferType, value)
.getResult();
}
@@ -753,8 +750,8 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
// lose all of its users and eventually DCE away.
rewriter.setInsertionPointAfter(op);
- replacement = rewriter.create<bufferization::ToTensorOp>(
- replacement.getLoc(), opResult.getType(), replacement);
+ replacement = bufferization::ToTensorOp::create(
+ rewriter, replacement.getLoc(), opResult.getType(), replacement);
}
replacements.push_back(replacement);
}
@@ -775,11 +772,10 @@ FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
// Default buffer allocation via AllocOp.
if (bufferAlignment != 0)
- return b
- .create<memref::AllocOp>(loc, type, dynShape,
- b.getI64IntegerAttr(bufferAlignment))
+ return memref::AllocOp::create(b, loc, type, dynShape,
+ b.getI64IntegerAttr(bufferAlignment))
.getResult();
- return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
+ return memref::AllocOp::create(b, loc, type, dynShape).getResult();
}
/// Create a memory copy between two memref buffers.
@@ -788,7 +784,7 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
if (memCpyFn)
return (*memCpyFn)(b, loc, from, to);
- b.create<memref::CopyOp>(loc, from, to);
+ memref::CopyOp::create(b, loc, from, to);
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 875a065..7eb729f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -58,7 +58,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
// a fix extra conditions in `isGuaranteedCastCompatible`.
if (memref::CastOp::areCastCompatible(srcType, destType) &&
isGuaranteedCastCompatible(srcType, destType)) {
- Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
+ Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
return casted;
}
@@ -67,7 +67,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
for (int i = 0; i < destType.getRank(); ++i) {
if (destType.getShape()[i] != ShapedType::kDynamic)
continue;
- Value size = b.create<memref::DimOp>(loc, value, i);
+ Value size = memref::DimOp::create(b, loc, value, i);
dynamicOperands.push_back(size);
}
@@ -134,10 +134,10 @@ void mlir::bufferization::populateDynamicDimSizes(
for (int64_t i = 0; i < shapedType.getRank(); ++i) {
if (shapedType.isDynamicDim(i)) {
if (llvm::isa<MemRefType>(shapedType)) {
- dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
+ dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
} else {
assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
- dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
+ dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
}
}
}
@@ -321,8 +321,8 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
newShape, op.getType().getElementType(), op.getType().getEncoding());
if (newType == op.getType())
return failure();
- auto newOp = rewriter.create<AllocTensorOp>(
- op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
+ auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
+ newDynamicSizes, /*copy=*/Value());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
@@ -427,7 +427,7 @@ void AllocTensorOp::print(OpAsmPrinter &p) {
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
assert(isDynamicDim(idx) && "expected dynamic dim");
if (getCopy())
- return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
+ return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
return getOperand(getIndexOfDynamicSize(idx));
}
@@ -513,8 +513,8 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
}
if (source.getType() != cloneOp.getType())
- source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
- cloneOp.getType(), source);
+ source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
+ cloneOp.getType(), source);
rewriter.replaceOp(cloneOp, source);
rewriter.eraseOp(redundantDealloc);
return success();
@@ -538,7 +538,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
if (failed(buffer))
return failure();
- rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
+ memref::DeallocOp::create(rewriter, getLoc(), *buffer);
rewriter.eraseOp(getOperation());
return success();
}
@@ -643,8 +643,9 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
assert(getRestrict() &&
"expected that ops with memrefs dest have 'restrict'");
setRestrict(false);
- return builder.create<ToTensorOp>(
- loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(),
+ return ToTensorOp::create(
+ builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
+ getDest(),
/*restrict=*/true, getWritable());
}
@@ -804,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
tensorCastOperand.getOperand().getType());
if (!srcTensorType)
return failure();
+ auto currentOutputMemRefType =
+ dyn_cast<MemRefType>(toBuffer.getResult().getType());
+ if (!currentOutputMemRefType)
+ return failure();
+
auto memrefType = MemRefType::get(srcTensorType.getShape(),
- srcTensorType.getElementType());
- Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
- tensorCastOperand.getOperand());
+ srcTensorType.getElementType(),
+ currentOutputMemRefType.getLayout(),
+ currentOutputMemRefType.getMemorySpace());
+ Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
+ tensorCastOperand.getOperand(),
+ toBuffer.getReadOnly());
rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
memref);
return success();
@@ -880,12 +889,12 @@ LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
Value alloc) {
- return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+ return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
.getOperation();
}
std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
- return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
+ return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
}
//===----------------------------------------------------------------------===//
@@ -959,7 +968,7 @@ struct DeallocRemoveDuplicateDeallocMemrefs
Value &newCond = newConditions[memrefToCondition[memref]];
if (newCond != cond)
newCond =
- rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
+ arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
} else {
memrefToCondition.insert({memref, newConditions.size()});
newMemrefs.push_back(memref);
@@ -1014,8 +1023,8 @@ struct DeallocRemoveDuplicateRetainedMemrefs
// We need to create a new op because the number of results is always the
// same as the number of condition operands.
auto newDeallocOp =
- rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
- deallocOp.getConditions(), newRetained);
+ DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
+ deallocOp.getConditions(), newRetained);
SmallVector<Value> replacements(
llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
return newDeallocOp.getUpdatedConditions()[idx];
@@ -1036,8 +1045,8 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
if (deallocOp.getMemrefs().empty()) {
- Value constFalse = rewriter.create<arith::ConstantOp>(
- deallocOp.getLoc(), rewriter.getBoolAttr(false));
+ Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
+ rewriter.getBoolAttr(false));
rewriter.replaceOp(
deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
constFalse));
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20..7f495b0 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -70,12 +70,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
*getFunctionBoundaryTypeConversion());
if (getMemcpyOp() == "memref.copy") {
options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
- b.create<memref::CopyOp>(loc, from, to);
+ memref::CopyOp::create(b, loc, from, to);
return success();
};
} else if (getMemcpyOp() == "linalg.copy") {
options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
- b.create<linalg::CopyOp>(loc, from, to);
+ linalg::CopyOp::create(b, loc, from, to);
return success();
};
} else {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index c5fab80..8916526 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -167,8 +167,8 @@ struct RemoveDeallocMemrefsContainedInRetained
std::optional<bool> analysisResult =
analysis.isSameAllocation(retained, memref);
if (analysisResult == true) {
- auto disjunction = rewriter.create<arith::OrIOp>(
- deallocOp.getLoc(), updatedCondition, cond);
+ auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
+ updatedCondition, cond);
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
disjunction);
}
@@ -247,16 +247,16 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
continue;
}
- replacements.push_back(rewriter.create<arith::ConstantOp>(
- deallocOp.getLoc(), rewriter.getBoolAttr(false)));
+ replacements.push_back(arith::ConstantOp::create(
+ rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false)));
}
if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
return failure();
- auto newDeallocOp = rewriter.create<DeallocOp>(
- deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
- newRetainedMemrefs);
+ auto newDeallocOp =
+ DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
+ deallocOp.getConditions(), newRetainedMemrefs);
int i = 0;
for (auto &repl : replacements) {
if (!repl)
@@ -326,8 +326,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
}
// Create new bufferization.dealloc op for `memref`.
- auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
- deallocOp.getRetained());
+ auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
+ deallocOp.getRetained());
updatedConditions.push_back(
llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
}
@@ -337,8 +337,9 @@ struct SplitDeallocWhenNotAliasingAnyOther
return failure();
// Create bufferization.dealloc op for all remaining memrefs.
- auto newDeallocOp = rewriter.create<DeallocOp>(
- loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
+ auto newDeallocOp =
+ DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
+ deallocOp.getRetained());
// Bit-or all conditions.
SmallVector<Value> replacements =
@@ -347,8 +348,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
assert(replacements.size() == additionalConditions.size() &&
"expected same number of updated conditions");
for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
- replacements[i] = rewriter.create<arith::OrIOp>(
- loc, replacements[i], additionalConditions[i]);
+ replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
+ additionalConditions[i]);
}
}
rewriter.replaceOp(deallocOp, replacements);
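
The patterns in this file repeatedly fold per-memref i1 conditions with arith.ori; a minimal sketch of that accumulation step (hypothetical helper, assuming the static create overloads shown above):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Merge the condition already recorded for a memref with a newly seen one:
// the memref may be deallocated if either condition holds.
static Value mergeConditions(OpBuilder &b, Location loc, Value prevCond,
                             Value newCond) {
  if (prevCond == newCond)
    return prevCond; // identical SSA values need no extra op
  return arith::OrIOp::create(b, loc, prevCond, newCond);
}
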
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 6924e88..e30e094 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -132,7 +132,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
return WalkResult::interrupt();
}
}
- builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
+ func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
op.erase();
return WalkResult::advance();
});
@@ -190,7 +190,7 @@ updateCalls(ModuleOp module,
assert(hasFullyDynamicLayoutMap(memrefType) &&
"layout map not supported");
outParam =
- builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
+ memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
}
memref.replaceAllUsesWith(outParam);
outParams.push_back(outParam);
@@ -200,8 +200,8 @@ updateCalls(ModuleOp module,
newOperands.append(outParams.begin(), outParams.end());
auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
replaceWithNewCallResults, [](Value v) { return v.getType(); }));
- auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
- newResultTypes, newOperands);
+ auto newCall = func::CallOp::create(
+ builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
op.erase();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index a66be7d..c0e0809 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -141,8 +141,9 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type));
if (memorySpace)
memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
- auto global = globalBuilder.create<memref::GlobalOp>(
- constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
+ auto global = memref::GlobalOp::create(
+ globalBuilder, constantOp.getLoc(),
+ (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/memrefType,
/*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 246555d..91f6f25 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -434,8 +434,8 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
// Replace all uses of the original tensor bbArg.
rewriter.setInsertionPointToStart(block);
if (!bbArgUses.empty()) {
- Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
- bbArg.getLoc(), tensorType, bbArg);
+ Value toTensorOp = bufferization::ToTensorOp::create(
+ rewriter, bbArg.getLoc(), tensorType, bbArg);
for (OpOperand *use : bbArgUses)
use->set(toTensorOp);
}
@@ -466,13 +466,13 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
if (failed(operandBufferType))
return failure();
rewriter.setInsertionPointAfterValue(operand);
- Value bufferizedOperand = rewriter.create<bufferization::ToBufferOp>(
- operand.getLoc(), *operandBufferType, operand);
+ Value bufferizedOperand = bufferization::ToBufferOp::create(
+ rewriter, operand.getLoc(), *operandBufferType, operand);
// A cast is needed if the operand and the block argument have different
// bufferized types.
if (type != *operandBufferType)
- bufferizedOperand = rewriter.create<memref::CastOp>(
- operand.getLoc(), type, bufferizedOperand);
+ bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(),
+ type, bufferizedOperand);
newOperands.push_back(bufferizedOperand);
}
operands.getMutableForwardedOperands().assign(newOperands);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index c10d290..a50ddbe 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -118,8 +118,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
rewriter.setInsertionPoint(callOp);
- auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
- callOp.getOperands());
+ auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp,
+ callOp.getOperands());
SmallVector<Value> newResults;
int64_t nextResult = 0;
for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
@@ -134,8 +134,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
Type expectedType = callOp.getResult(i).getType();
if (replacement.getType() != expectedType) {
// A cast must be inserted at the call site.
- replacement = rewriter.create<memref::CastOp>(
- callOp.getLoc(), expectedType, replacement);
+ replacement = memref::CastOp::create(rewriter, callOp.getLoc(),
+ expectedType, replacement);
}
newResults.push_back(replacement);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index b7db2e8..1784964 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -168,8 +168,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
cast<ShapedType>(v.getType()).getElementType())
continue;
rewriter.setInsertionPointAfterValue(replacement);
- replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
- replacement);
+ replacement = tensor::CastOp::create(rewriter, v.getLoc(), v.getType(),
+ replacement);
}
// Replace the specific use of the tensor::EmptyOp.
rewriter.modifyOpInPlace(user, [&]() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 2a98203..f69efd1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -319,8 +319,9 @@ struct CallOpInterface
}
// 3. Create the new CallOp.
- Operation *newCallOp = rewriter.create<func::CallOp>(
- callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
+ Operation *newCallOp =
+ func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(),
+ resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// 4. Replace the old op with the new op.
@@ -483,8 +484,8 @@ struct FuncOpInterface
// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
- Value toBufferOp = rewriter.create<bufferization::ToBufferOp>(
- returnOp.getLoc(), bufferizedType, returnVal);
+ Value toBufferOp = bufferization::ToBufferOp::create(
+ rewriter, returnOp.getLoc(), bufferizedType, returnVal);
returnValues.push_back(toBufferOp);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index a611126..e9ad13f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -64,8 +64,8 @@ class DeallocOpConversion
rewriter.replaceOpWithNewOp<scf::IfOp>(
op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
- builder.create<scf::YieldOp>(loc);
+ memref::DeallocOp::create(builder, loc, adaptor.getMemrefs()[0]);
+ scf::YieldOp::create(builder, loc);
});
return success();
}
@@ -108,45 +108,46 @@ class DeallocOpConversion
// Compute the base pointer indices, compare all retained indices to the
// memref index to check if they alias.
SmallVector<Value> doesNotAliasList;
- Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
- op->getLoc(), adaptor.getMemrefs()[0]);
+ Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, op->getLoc(), adaptor.getMemrefs()[0]);
for (Value retained : adaptor.getRetained()) {
- Value retainedAsIdx =
- rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
- retained);
- Value doesNotAlias = rewriter.create<arith::CmpIOp>(
- op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
+ Value retainedAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, op->getLoc(), retained);
+ Value doesNotAlias = arith::CmpIOp::create(rewriter, op->getLoc(),
+ arith::CmpIPredicate::ne,
+ memrefAsIdx, retainedAsIdx);
doesNotAliasList.push_back(doesNotAlias);
}
// AND-reduce the list of booleans from above.
Value prev = doesNotAliasList.front();
for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
- prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
+ prev = arith::AndIOp::create(rewriter, op->getLoc(), prev, doesNotAlias);
// Also consider the condition given by the dealloc operation and perform a
// conditional deallocation guarded by that value.
- Value shouldDealloc = rewriter.create<arith::AndIOp>(
- op->getLoc(), prev, adaptor.getConditions()[0]);
+ Value shouldDealloc = arith::AndIOp::create(rewriter, op->getLoc(), prev,
+ adaptor.getConditions()[0]);
- rewriter.create<scf::IfOp>(
- op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
- builder.create<scf::YieldOp>(loc);
- });
+ scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc,
+ [&](OpBuilder &builder, Location loc) {
+ memref::DeallocOp::create(builder, loc,
+ adaptor.getMemrefs()[0]);
+ scf::YieldOp::create(builder, loc);
+ });
// Compute the replacement values for the dealloc operation results. This
// inserts an already canonicalized form of
// `select(does_alias_with_memref(r), memref_cond, false)` for each retained
// value r.
SmallVector<Value> replacements;
- Value trueVal = rewriter.create<arith::ConstantOp>(
- op->getLoc(), rewriter.getBoolAttr(true));
+ Value trueVal = arith::ConstantOp::create(rewriter, op->getLoc(),
+ rewriter.getBoolAttr(true));
for (Value doesNotAlias : doesNotAliasList) {
Value aliases =
- rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
- Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
- adaptor.getConditions()[0]);
+ arith::XOrIOp::create(rewriter, op->getLoc(), doesNotAlias, trueVal);
+ Value result = arith::AndIOp::create(rewriter, op->getLoc(), aliases,
+ adaptor.getConditions()[0]);
replacements.push_back(result);
}
@@ -230,108 +231,112 @@ class DeallocOpConversion
// Without storing them to memrefs, we could not use for-loops but only a
// completely unrolled version of it, potentially leading to code-size
// blow-up.
- Value toDeallocMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
- rewriter.getIndexType()));
- Value conditionMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
- rewriter.getI1Type()));
- Value toRetainMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
- rewriter.getIndexType()));
+ Value toDeallocMemref = memref::AllocOp::create(
+ rewriter, op.getLoc(),
+ MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+ rewriter.getIndexType()));
+ Value conditionMemref = memref::AllocOp::create(
+ rewriter, op.getLoc(),
+ MemRefType::get({(int64_t)adaptor.getConditions().size()},
+ rewriter.getI1Type()));
+ Value toRetainMemref = memref::AllocOp::create(
+ rewriter, op.getLoc(),
+ MemRefType::get({(int64_t)adaptor.getRetained().size()},
+ rewriter.getIndexType()));
auto getConstValue = [&](uint64_t value) -> Value {
- return rewriter.create<arith::ConstantOp>(op.getLoc(),
- rewriter.getIndexAttr(value));
+ return arith::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getIndexAttr(value));
};
// Extract the base pointers of the memrefs as indices to check for aliasing
// at runtime.
for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
- Value memrefAsIdx =
- rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
- toDealloc);
- rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
- toDeallocMemref, getConstValue(i));
+ Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, op.getLoc(), toDealloc);
+ memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx,
+ toDeallocMemref, getConstValue(i));
}
for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
- rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
- getConstValue(i));
+ memref::StoreOp::create(rewriter, op.getLoc(), cond, conditionMemref,
+ getConstValue(i));
for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
- Value memrefAsIdx =
- rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
- toRetain);
- rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
- getConstValue(i));
+ Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, op.getLoc(), toRetain);
+ memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx,
+ toRetainMemref, getConstValue(i));
}
// Cast the allocated memrefs to dynamic shape because we want only one
// helper function no matter how many operands the bufferization.dealloc
// has.
- Value castedDeallocMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
+ Value castedDeallocMemref = memref::CastOp::create(
+ rewriter, op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
toDeallocMemref);
- Value castedCondsMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
+ Value castedCondsMemref = memref::CastOp::create(
+ rewriter, op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
conditionMemref);
- Value castedRetainMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
+ Value castedRetainMemref = memref::CastOp::create(
+ rewriter, op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
toRetainMemref);
- Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
- rewriter.getI1Type()));
- Value retainCondsMemref = rewriter.create<memref::AllocOp>(
- op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
- rewriter.getI1Type()));
-
- Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
+ Value deallocCondsMemref = memref::AllocOp::create(
+ rewriter, op.getLoc(),
+ MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+ rewriter.getI1Type()));
+ Value retainCondsMemref = memref::AllocOp::create(
+ rewriter, op.getLoc(),
+ MemRefType::get({(int64_t)adaptor.getRetained().size()},
+ rewriter.getI1Type()));
+
+ Value castedDeallocCondsMemref = memref::CastOp::create(
+ rewriter, op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
deallocCondsMemref);
- Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
- op->getLoc(),
+ Value castedRetainCondsMemref = memref::CastOp::create(
+ rewriter, op->getLoc(),
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
retainCondsMemref);
Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
- rewriter.create<func::CallOp>(
- op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
+ func::CallOp::create(
+ rewriter, op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
castedCondsMemref, castedDeallocCondsMemref,
castedRetainCondsMemref});
for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
Value idxValue = getConstValue(i);
- Value shouldDealloc = rewriter.create<memref::LoadOp>(
- op.getLoc(), deallocCondsMemref, idxValue);
- rewriter.create<scf::IfOp>(
- op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
- builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
- builder.create<scf::YieldOp>(loc);
- });
+ Value shouldDealloc = memref::LoadOp::create(
+ rewriter, op.getLoc(), deallocCondsMemref, idxValue);
+ scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc,
+ [&](OpBuilder &builder, Location loc) {
+ memref::DeallocOp::create(builder, loc,
+ adaptor.getMemrefs()[i]);
+ scf::YieldOp::create(builder, loc);
+ });
}
SmallVector<Value> replacements;
for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
Value idxValue = getConstValue(i);
- Value ownership = rewriter.create<memref::LoadOp>(
- op.getLoc(), retainCondsMemref, idxValue);
+ Value ownership = memref::LoadOp::create(rewriter, op.getLoc(),
+ retainCondsMemref, idxValue);
replacements.push_back(ownership);
}
// Deallocate the memrefs allocated above to avoid memory leaks.
// Deallocation will not be run on code after this stage.
- rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
- rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
+ memref::DeallocOp::create(rewriter, op.getLoc(), toDeallocMemref);
+ memref::DeallocOp::create(rewriter, op.getLoc(), toRetainMemref);
+ memref::DeallocOp::create(rewriter, op.getLoc(), conditionMemref);
+ memref::DeallocOp::create(rewriter, op.getLoc(), deallocCondsMemref);
+ memref::DeallocOp::create(rewriter, op.getLoc(), retainCondsMemref);
rewriter.replaceOp(op, replacements);
return success();
@@ -349,8 +354,8 @@ public:
ConversionPatternRewriter &rewriter) const override {
// Lower the trivial case.
if (adaptor.getMemrefs().empty()) {
- Value falseVal = rewriter.create<arith::ConstantOp>(
- op.getLoc(), rewriter.getBoolAttr(false));
+ Value falseVal = arith::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getBoolAttr(false));
rewriter.replaceOp(
op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
return success();
@@ -449,93 +454,92 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
Value retainCondsMemref = helperFuncOp.getArguments()[4];
// Insert some prerequisites.
- Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
- Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
+ Value c0 = arith::ConstantOp::create(builder, loc, builder.getIndexAttr(0));
+ Value c1 = arith::ConstantOp::create(builder, loc, builder.getIndexAttr(1));
Value trueValue =
- builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
Value falseValue =
- builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
- Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
- Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(false));
+ Value toDeallocSize =
+ memref::DimOp::create(builder, loc, toDeallocMemref, c0);
+ Value toRetainSize = memref::DimOp::create(builder, loc, toRetainMemref, c0);
- builder.create<scf::ForOp>(
- loc, c0, toRetainSize, c1, ValueRange(),
+ scf::ForOp::create(
+ builder, loc, c0, toRetainSize, c1, ValueRange(),
[&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
- builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
- builder.create<scf::YieldOp>(loc);
+ memref::StoreOp::create(builder, loc, falseValue, retainCondsMemref, i);
+ scf::YieldOp::create(builder, loc);
});
- builder.create<scf::ForOp>(
- loc, c0, toDeallocSize, c1, ValueRange(),
+ scf::ForOp::create(
+ builder, loc, c0, toDeallocSize, c1, ValueRange(),
[&](OpBuilder &builder, Location loc, Value outerIter,
ValueRange iterArgs) {
Value toDealloc =
- builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
+ memref::LoadOp::create(builder, loc, toDeallocMemref, outerIter);
Value cond =
- builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
+ memref::LoadOp::create(builder, loc, conditionMemref, outerIter);
// Build the first for loop that computes aliasing with retained
// memrefs.
- Value noRetainAlias =
- builder
- .create<scf::ForOp>(
- loc, c0, toRetainSize, c1, trueValue,
+ Value
+ noRetainAlias =
+ scf::ForOp::create(
+ builder, loc, c0, toRetainSize, c1, trueValue,
[&](OpBuilder &builder, Location loc, Value i,
ValueRange iterArgs) {
- Value retainValue = builder.create<memref::LoadOp>(
- loc, toRetainMemref, i);
- Value doesAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, retainValue,
+ Value retainValue = memref::LoadOp::create(
+ builder, loc, toRetainMemref, i);
+ Value doesAlias = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, retainValue,
toDealloc);
- builder.create<scf::IfOp>(
- loc, doesAlias,
+ scf::IfOp::create(
+ builder, loc, doesAlias,
[&](OpBuilder &builder, Location loc) {
- Value retainCondValue =
- builder.create<memref::LoadOp>(
- loc, retainCondsMemref, i);
- Value aggregatedRetainCond =
- builder.create<arith::OrIOp>(
- loc, retainCondValue, cond);
- builder.create<memref::StoreOp>(
- loc, aggregatedRetainCond, retainCondsMemref,
- i);
- builder.create<scf::YieldOp>(loc);
+ Value retainCondValue = memref::LoadOp::create(
+ builder, loc, retainCondsMemref, i);
+ Value aggregatedRetainCond = arith::OrIOp::create(
+ builder, loc, retainCondValue, cond);
+ memref::StoreOp::create(builder, loc,
+ aggregatedRetainCond,
+ retainCondsMemref, i);
+ scf::YieldOp::create(builder, loc);
});
- Value doesntAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, retainValue,
+ Value doesntAlias = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::ne, retainValue,
toDealloc);
- Value yieldValue = builder.create<arith::AndIOp>(
- loc, iterArgs[0], doesntAlias);
- builder.create<scf::YieldOp>(loc, yieldValue);
+ Value yieldValue = arith::AndIOp::create(
+ builder, loc, iterArgs[0], doesntAlias);
+ scf::YieldOp::create(builder, loc, yieldValue);
})
- .getResult(0);
+ .getResult(0);
// Build the second for loop that adds aliasing with previously
// deallocated memrefs.
- Value noAlias =
- builder
- .create<scf::ForOp>(
- loc, c0, outerIter, c1, noRetainAlias,
+ Value
+ noAlias =
+ scf::ForOp::create(
+ builder, loc, c0, outerIter, c1, noRetainAlias,
[&](OpBuilder &builder, Location loc, Value i,
ValueRange iterArgs) {
- Value prevDeallocValue = builder.create<memref::LoadOp>(
- loc, toDeallocMemref, i);
- Value doesntAlias = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, prevDeallocValue,
- toDealloc);
- Value yieldValue = builder.create<arith::AndIOp>(
- loc, iterArgs[0], doesntAlias);
- builder.create<scf::YieldOp>(loc, yieldValue);
+ Value prevDeallocValue = memref::LoadOp::create(
+ builder, loc, toDeallocMemref, i);
+ Value doesntAlias = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::ne,
+ prevDeallocValue, toDealloc);
+ Value yieldValue = arith::AndIOp::create(
+ builder, loc, iterArgs[0], doesntAlias);
+ scf::YieldOp::create(builder, loc, yieldValue);
})
- .getResult(0);
+ .getResult(0);
- Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
- builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
- outerIter);
- builder.create<scf::YieldOp>(loc);
+      Value shouldDealloc = arith::AndIOp::create(builder, loc, noAlias, cond);
+      memref::StoreOp::create(builder, loc, shouldDealloc, deallocCondsMemref,
+ outerIter);
+ scf::YieldOp::create(builder, loc);
});
- builder.create<func::ReturnOp>(loc);
+ func::ReturnOp::create(builder, loc);
return helperFuncOp;
}
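
For reference, the runtime alias check the converted code emits can be condensed as follows (a sketch with a hypothetical buildDoesNotAliasAny helper; the general lowering above additionally stores indices to scratch memrefs and calls the library function):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Compare the aligned base pointer (as an index) of `memref` against every
// retained value and AND-reduce the per-pair "does not alias" bits.
static Value buildDoesNotAliasAny(OpBuilder &b, Location loc, Value memref,
                                  ValueRange retained) {
  Value memrefIdx =
      memref::ExtractAlignedPointerAsIndexOp::create(b, loc, memref);
  Value allDistinct;
  for (Value r : retained) {
    Value rIdx = memref::ExtractAlignedPointerAsIndexOp::create(b, loc, r);
    Value ne = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne,
                                     memrefIdx, rIdx);
    allDistinct = allDistinct
                      ? Value(arith::AndIOp::create(b, loc, allDistinct, ne))
                      : ne;
  }
  return allDistinct; // null when `retained` is empty
}
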
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d1d1062..aa53f94 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -1,4 +1,4 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,12 +9,13 @@
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp.
+// implementations for FuncOp, CallOp and ReturnOp. Although it is named
+// Module Bufferization, it may operate on any SymbolTable.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
-// and bufferization: Functions that are called are processed before their
-// respective callers.
+// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
+// ...)`. This function analyzes the given op and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
@@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
+          assert(calledFunction && "could not retrieve called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
}
void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+ Operation *moduleOp) {
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(moduleOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
index 605a487..b8ddee6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
@@ -18,11 +18,9 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "optimize-allocation-liveness"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir {
namespace bufferization {
@@ -65,8 +63,8 @@ Operation *findUserWithFreeSideEffect(Value value) {
for (const auto &effect : effects) {
if (isa<MemoryEffects::Free>(effect.getEffect())) {
if (freeOpUser) {
- LDBG("Multiple users with free effect found: " << *freeOpUser
- << " and " << *user);
+ LDBG() << "Multiple users with free effect found: " << *freeOpUser
+ << " and " << *user;
return nullptr;
}
freeOpUser = user;
@@ -121,7 +119,7 @@ public:
return WalkResult::advance();
auto allocOp = memEffectOp;
- LDBG("Checking alloc op: " << allocOp);
+ LDBG() << "Checking alloc op: " << allocOp;
SmallVector<OpResult> allocationResults = collectAllocations(allocOp);
// Multiple allocations from a single op are not considered here yet.
@@ -129,7 +127,7 @@ public:
return WalkResult::advance();
OpResult allocResult = allocationResults[0];
- LDBG("On allocation result: " << allocResult);
+ LDBG() << "On allocation result: " << allocResult;
auto *deallocOp = findUserWithFreeSideEffect(allocResult);
if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) {
@@ -159,12 +157,12 @@ public:
if (lastUser == nullptr) {
return WalkResult::advance();
}
- LDBG("Last user found: " << *lastUser);
+ LDBG() << "Last user found: " << *lastUser;
assert(lastUser->getBlock() == allocOp->getBlock());
assert(lastUser->getBlock() == deallocOp->getBlock());
// Move the dealloc op after the last user.
deallocOp->moveAfter(lastUser);
- LDBG("Moved dealloc op after: " << *lastUser);
+ LDBG() << "Moved dealloc op after: " << *lastUser;
return WalkResult::advance();
});
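
The logging change in this file (and in the GPU files below) swaps the hand-rolled DBGS/LDBG macros for the LDBG() stream from llvm/Support/DebugLog.h; a minimal before/after sketch:

#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "optimize-allocation-liveness"

// LDBG() from DebugLog.h is a stream that emits the active DEBUG_TYPE
// prefix and the trailing newline itself, replacing:
//   #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
//   #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
static void logUserCount(int users) {
  // Old: LDBG("found " << users << " users");
  LDBG() << "found " << users << " users";
}
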
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 1eeafc4..725fa24 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -43,7 +43,7 @@ using namespace mlir::bufferization;
//===----------------------------------------------------------------------===//
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
- return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
+ return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value));
}
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
@@ -750,19 +750,17 @@ Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
// Insert a runtime check and only clone if we still don't have ownership at
// runtime.
- Value maybeClone =
- builder
- .create<scf::IfOp>(
- memref.getLoc(), condition,
- [&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, newMemref);
- },
- [&](OpBuilder &builder, Location loc) {
- Value clone =
- builder.create<bufferization::CloneOp>(loc, newMemref);
- builder.create<scf::YieldOp>(loc, clone);
- })
- .getResult(0);
+ Value maybeClone = scf::IfOp::create(
+ builder, memref.getLoc(), condition,
+ [&](OpBuilder &builder, Location loc) {
+ scf::YieldOp::create(builder, loc, newMemref);
+ },
+ [&](OpBuilder &builder, Location loc) {
+ Value clone = bufferization::CloneOp::create(
+ builder, loc, newMemref);
+ scf::YieldOp::create(builder, loc, clone);
+ })
+ .getResult(0);
Value trueVal = buildBoolValue(builder, memref.getLoc(), true);
state.updateOwnership(maybeClone, trueVal);
state.addMemrefToDeallocate(maybeClone, maybeClone.getParentBlock());
@@ -797,8 +795,8 @@ BufferDeallocation::handleInterface(BranchOpInterface op) {
state.getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands,
toRetain);
- auto deallocOp = builder.create<bufferization::DeallocOp>(
- op.getLoc(), memrefs, conditions, toRetain);
+ auto deallocOp = bufferization::DeallocOp::create(
+ builder, op.getLoc(), memrefs, conditions, toRetain);
// We want to replace the current ownership of the retained values with the
// result values of the dealloc operation as they are always unique.
@@ -885,12 +883,11 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
builder.setInsertionPoint(op);
Ownership ownership = state.getOwnership(operand, block);
if (ownership.isUnique()) {
- Value ownershipInverted = builder.create<arith::XOrIOp>(
- op.getLoc(), ownership.getIndicator(),
+ Value ownershipInverted = arith::XOrIOp::create(
+ builder, op.getLoc(), ownership.getIndicator(),
buildBoolValue(builder, op.getLoc(), true));
- builder.create<cf::AssertOp>(
- op.getLoc(), ownershipInverted,
- "expected that the block does not have ownership");
+ cf::AssertOp::create(builder, op.getLoc(), ownershipInverted,
+ "expected that the block does not have ownership");
}
}
}
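
The restructured scf.if above reads more directly when pulled into a standalone sketch (hypothetical helper, built from the same ops the patch uses):

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Yield the memref unchanged when `condition` proves ownership at runtime;
// otherwise clone it so the returned value is guaranteed to be owned.
static Value cloneUnlessOwned(OpBuilder &builder, Location loc, Value memref,
                              Value condition) {
  return scf::IfOp::create(
             builder, loc, condition,
             [&](OpBuilder &b, Location l) {
               scf::YieldOp::create(b, l, memref);
             },
             [&](OpBuilder &b, Location l) {
               Value clone = bufferization::CloneOp::create(b, l, memref);
               scf::YieldOp::create(b, l, clone);
             })
      .getResult(0);
}
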
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index f999c93..a6159ee 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
- if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
+ if (failed(analyzeModuleOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 3cc52eb..053ee95 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
add_subdirectory(MLProgram)
add_subdirectory(MPI)
add_subdirectory(NVGPU)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 568da89..4c09022 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -171,10 +171,9 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
/// In the format string, all `{}` are replaced by Placeholders, except if the
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
template <class ArgType>
-FailureOr<SmallVector<ReplacementItem>>
-parseFormatString(StringRef toParse, ArgType fmtArgs,
- std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
- emitError = {}) {
+FailureOr<SmallVector<ReplacementItem>> parseFormatString(
+ StringRef toParse, ArgType fmtArgs,
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError = {}) {
SmallVector<ReplacementItem> items;
// If there are no operands, the format string is not interpreted.
@@ -197,8 +196,7 @@ parseFormatString(StringRef toParse, ArgType fmtArgs,
continue;
}
if (toParse.size() < 2) {
- return (*emitError)()
- << "expected '}' after unescaped '{' at end of string";
+ return emitError() << "expected '}' after unescaped '{' at end of string";
}
// toParse contains at least two characters and starts with `{`.
char nextChar = toParse[1];
@@ -214,8 +212,8 @@ parseFormatString(StringRef toParse, ArgType fmtArgs,
continue;
}
- if (emitError.has_value()) {
- return (*emitError)() << "expected '}' after unescaped '{'";
+ if (emitError) {
+ return emitError() << "expected '}' after unescaped '{'";
}
return failure();
}
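
The signature change works because llvm::function_ref is itself nullable: default-constructed it tests false, so wrapping it in std::optional was redundant. A tiny sketch:

#include "llvm/ADT/STLFunctionalExtras.h"

// A null function_ref models "no diagnostic callback" directly.
static int callOrDefault(llvm::function_ref<int()> cb = {}) {
  if (cb) // non-null: a callback was supplied
    return cb();
  return -1; // null: behave as if no emitError was passed
}
// callOrDefault() == -1; callOrDefault([] { return 7; }) == 7.
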
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index d5fe3b4..3f0690c 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -62,9 +62,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
continue;
for (Value operand : op.getOperands()) {
- auto usedExpression =
- dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
-
+ auto usedExpression = operand.getDefiningOp<ExpressionOp>();
if (!usedExpression)
continue;
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 612e809..fa05ad8 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -31,7 +31,7 @@ struct WrapFuncInClassPass
Operation *rootOp = getOperation();
RewritePatternSet patterns(&getContext());
- populateFuncPatterns(patterns, namedAttribute);
+ populateFuncPatterns(patterns);
walkAndApplyPatterns(rootOp, std::move(patterns));
}
@@ -43,8 +43,8 @@ struct WrapFuncInClassPass
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
public:
- WrapFuncInClass(MLIRContext *context, StringRef attrName)
- : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
+ WrapFuncInClass(MLIRContext *context)
+ : OpRewritePattern<emitc::FuncOp>(context) {}
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
PatternRewriter &rewriter) const override {
@@ -101,12 +101,8 @@ public:
rewriter.replaceOp(funcOp, newClassOp);
return success();
}
-
-private:
- StringRef attributeName;
};
-void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
- StringRef namedAttribute) {
- patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
+ patterns.add<WrapFuncInClass>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index eb6b59b..1b18ef2 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,7 +8,7 @@
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
-#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 47363f4..87ef51e 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,7 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
InlinerExtension.cpp
- MeshShardingExtensions.cpp
+ ShardingExtensions.cpp
)
add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -17,8 +17,8 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
MLIRFuncDialect
)
-add_mlir_extension_library(MLIRFuncMeshShardingExtensions
- MeshShardingExtensions.cpp
+add_mlir_extension_library(MLIRFuncShardingExtensions
+ ShardingExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
@@ -38,5 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
LINK_LIBS PUBLIC
MLIRFuncInlinerExtension
- MLIRFuncMeshShardingExtensions
+ MLIRFuncShardingExtensions
)
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
index da508cc..dfd1348 100644
--- a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//===- ShardingExtensions.cpp - -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/IR/MLIRContext.h"
namespace mlir::func {
@@ -16,7 +16,7 @@ namespace mlir::func {
void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) {
ReturnOp::attachInterface<
- mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
+ shard::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
*ctx);
});
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d186a48..5a72ef1 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
// RotateOp
//===----------------------------------------------------------------------===//
-void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
- int32_t offset, int32_t width) {
- build(builder, result, value,
- arith::ConstantOp::create(builder, result.location,
- builder.getI32IntegerAttr(offset)),
- arith::ConstantOp::create(builder, result.location,
- builder.getI32IntegerAttr(width)));
-}
-
LogicalResult RotateOp::verify() {
- auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
- if (!offsetConstOp)
- return emitOpError() << "offset is not a constant value";
-
- auto offsetIntAttr =
- llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
-
- auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
- if (!widthConstOp)
- return emitOpError() << "width is not a constant value";
-
- auto widthIntAttr =
- llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
-
- llvm::APInt offsetValue = offsetIntAttr.getValue();
- llvm::APInt widthValue = widthIntAttr.getValue();
-
- if (!widthValue.isPowerOf2())
- return emitOpError() << "width must be a power of two";
+ uint32_t offset = getOffset();
+ uint32_t width = getWidth();
- if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
- int64_t widthValueInt = widthValue.getSExtValue();
- return emitOpError() << "offset must be in the range [0, " << widthValueInt
- << ")";
+ if (offset >= width) {
+ return emitOpError() << "offset must be in the range [0, " << width << ")";
}
return success();
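
With offset and width now plain integer attributes rather than constant operands, the verifier reduces to one unsigned comparison; the former negativity check is vacuous for uint32_t. A sketch of the remaining invariant (the power-of-two width constraint is assumed to be enforced elsewhere, e.g. by an ODS attribute constraint, after this change):

#include <cstdint>

// offset must lie in [0, width); unsigned-ness rules out offset < 0.
static bool isValidRotate(uint32_t offset, uint32_t width) {
  return offset < width;
}
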
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 1d8279c..21cb2f6 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -39,7 +39,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
@@ -51,11 +51,6 @@ using namespace mlir::transform;
using namespace mlir::transform::gpu;
#define DEBUG_TYPE "gpu-transforms"
-#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"
-
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
@@ -471,7 +466,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
- LDBG("--start rewriteOneForallCommonImpl");
+ LDBG() << "--start rewriteOneForallCommonImpl";
// Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
auto numParallelIterations =
@@ -506,14 +501,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Otherwise, we have a new insertion without a size -> use size 1.
tmpMappingSizes.push_back(1);
}
- LDBG("----tmpMappingSizes extracted from scf.forall op: "
- << llvm::interleaved(tmpMappingSizes));
+ LDBG() << "----tmpMappingSizes extracted from scf.forall op: "
+ << llvm::interleaved(tmpMappingSizes);
// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
- LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
- LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));
+ LDBG() << "----forallMappingSizes: " << llvm::interleaved(forallMappingSizes);
+ LDBG() << "----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs);
// Step 3. Generate the mappingIdOps using the provided generator.
Location loc = forallOp.getLoc();
@@ -522,24 +517,23 @@
SmallVector<int64_t> originalBasis(availableMappingSizes);
bool originalBasisWasProvided = !originalBasis.empty();
if (!originalBasisWasProvided) {
- LDBG("----originalBasis was not provided, deriving it and there will be no "
- "predication");
+ LDBG() << "----originalBasis was not provided, deriving it and there will "
+ "be no "
+ "predication";
originalBasis = forallMappingSizes;
while (originalBasis.size() < 3)
originalBasis.push_back(1);
} else {
- LDBG("----originalBasis was provided, using it, there will be predication");
+ LDBG() << "----originalBasis was provided, using it, there will be "
+ "predication";
}
- LLVM_DEBUG(
- llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: ");
- llvm::dbgs() << "\n");
+ LDBG() << "------originalBasis: " << llvm::interleaved(originalBasis);
IdBuilderResult builderResult =
gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
if (!builderResult.errorMsg.empty())
return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);
- LLVM_DEBUG(DBGS() << builderResult);
+ LDBG() << builderResult;
// Step 4. Map the induction variables to the mappingIdOps, this may involve
// a permutation.
@@ -550,7 +545,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
- LDBG("----map: " << iv << " to " << peIdOp);
+ LDBG() << "----map: " << iv << " to " << peIdOp;
bvm.map(iv, peIdOp);
}
@@ -596,9 +591,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Step 8. Erase old op.
rewriter.eraseOp(forallOp);
- LDBG("----result forallMappingSizes: "
- << llvm::interleaved(forallMappingSizes));
- LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));
+ LDBG() << "----result forallMappingSizes: "
+ << llvm::interleaved(forallMappingSizes);
+ LDBG() << "----result mappingIdOps: " << llvm::interleaved(mappingIdOps);
result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
return DiagnosedSilenceableFailure::success();
@@ -612,7 +607,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
RewriterBase &rewriter, TransformOpInterface transformOp,
scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
const GpuIdBuilder &gpuIdBuilder) {
- LDBG("Start mapForallToBlocksImpl");
+ LDBG() << "Start mapForallToBlocksImpl";
{
// GPU-specific verifications. There is no better place to anchor
@@ -893,7 +888,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
bool syncAfterDistribute) {
- LDBG("Start mapNestedForallToThreadsImpl");
+ LDBG() << "Start mapNestedForallToThreadsImpl";
if (blockDims.size() != 3) {
return definiteFailureHelper(transformOp, target,
"requires size-3 thread mapping");
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 2fba09b..05bd917 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -27,7 +27,8 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
using namespace mlir::gpu;
@@ -36,10 +37,6 @@ using namespace mlir::transform::gpu;
#define DEBUG_TYPE "gpu-transforms"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
-
/// Build predicates to filter execution by only the activeIds. Along each
/// dimension, 3 cases appear:
/// 1. activeMappingSize > availableMappingSize: this is an unsupported case
@@ -54,15 +51,9 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
ArrayRef<int64_t> activeMappingSizes,
ArrayRef<int64_t> availableMappingSizes,
std::string &errorMsg) {
- // clang-format off
- LLVM_DEBUG(
- llvm::interleaveComma(
- activeMappingSizes, DBGS() << "----activeMappingSizes: ");
- DBGS() << "\n";
- llvm::interleaveComma(
- availableMappingSizes, DBGS() << "----availableMappingSizes: ");
- DBGS() << "\n";);
- // clang-format on
+ LDBG() << "----activeMappingSizes: " << llvm::interleaved(activeMappingSizes);
+ LDBG() << "----availableMappingSizes: "
+ << llvm::interleaved(availableMappingSizes);
SmallVector<Value> predicateOps;
for (auto [activeId, activeMappingSize, availableMappingSize] :
@@ -88,10 +79,8 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
template <typename ThreadOrBlockIdOp>
static Value buildLinearId(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> originalBasisOfr) {
- LLVM_DEBUG(llvm::interleaveComma(
- originalBasisOfr,
- DBGS() << "----buildLinearId with originalBasisOfr: ");
- llvm::dbgs() << "\n");
+ LDBG() << "----buildLinearId with originalBasisOfr: "
+ << llvm::interleaved(originalBasisOfr);
assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
IndexType indexType = rewriter.getIndexType();
AffineExpr tx, ty, tz, bdx, bdy;
@@ -157,7 +146,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1,
mask.createLogicalLinearMappingId(rewriter, scaledLinearIdI64);
scaledLinearId = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getIndexType(), logicalLinearIdI64);
- LDBG("------adjusting linearId with mask: " << scaledLinearId);
+ LDBG() << "------adjusting linearId with mask: " << scaledLinearId;
}
// 3. Compute remapped indices.
@@ -179,7 +168,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1,
if (mask) {
Value isActiveIdPredicate =
mask.createIsActiveIdPredicate(rewriter, scaledLinearIdI64);
- LDBG("------adjusting predicate with mask: " << isActiveIdPredicate);
+ LDBG() << "------adjusting predicate with mask: " << isActiveIdPredicate;
predicateOps.push_back(isActiveIdPredicate);
} else {
// 4.b. Otherwise, handle predicates using physicalLinearId.
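Alongside the logging change, the multi-statement llvm::interleaveComma sequences collapse into llvm::interleaved from llvm/Support/InterleavedRange.h, which wraps a range in a printable adaptor (comma-separated by default). A standalone sketch:

#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

int main() {
  std::vector<int> sizes = {4, 8, 16};
  // Prints "----activeMappingSizes: 4, 8, 16" in one streaming expression,
  // where the removed code needed an interleaveComma call per range plus
  // separate stream writes for the prefix and the newline.
  llvm::outs() << "----activeMappingSizes: " << llvm::interleaved(sizes)
               << "\n";
  return 0;
}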
diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
index d88f4d5..8e05436 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -60,14 +60,12 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
// Shuffle the values.
ValueRange loRes =
- rewriter
- .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
- op.getWidth(), op.getMode())
+ gpu::ShuffleOp::create(rewriter, op.getLoc(), lo, op.getOffset(),
+ op.getWidth(), op.getMode())
.getResults();
ValueRange hiRes =
- rewriter
- .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(),
- op.getWidth(), op.getMode())
+ gpu::ShuffleOp::create(rewriter, op.getLoc(), hi, op.getOffset(),
+ op.getWidth(), op.getMode())
.getResults();
// Convert lo back to i64.
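This hunk belongs to a mechanical migration from the member-template spelling rewriter.create<OpTy>(...) to the static OpTy::create(rewriter, ...). Both build the same operation; a sketch of the two forms side by side (buildLowShuffle is a hypothetical helper, not part of the patch):

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper contrasting the two builder spellings.
static mlir::ValueRange buildLowShuffle(mlir::RewriterBase &rewriter,
                                        mlir::gpu::ShuffleOp op,
                                        mlir::Value lo) {
  // Old: rewriter.create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
  //                                      op.getWidth(), op.getMode());
  // New: the builder moves to the first argument of the static create.
  return mlir::gpu::ShuffleOp::create(rewriter, op.getLoc(), lo,
                                      op.getOffset(), op.getWidth(),
                                      op.getMode())
      .getResults();
}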
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index b9e2dd5..b45fdf3 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -197,10 +197,9 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
// Parallel reduction using butterfly shuffles.
for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
i <<= 1) {
- Value shuffled = builder
- .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
- /*width=*/ci.subgroupSize,
- /*mode=*/gpu::ShuffleMode::XOR)
+ Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i,
+ /*width=*/ci.subgroupSize,
+ /*mode=*/gpu::ShuffleMode::XOR)
.getShuffleResult();
laneVal = vector::makeArithReduction(builder, loc,
gpu::convertReductionKind(mode),
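The loop above is a butterfly (XOR) reduction: at each step, lane l combines its value with the value shuffled in from lane l ^ offset, so after log2(clusterSize) steps every lane holds the full reduction. A minimal CPU-side sketch of the exchange pattern, with plain arrays standing in for subgroup lanes:

#include <cstdio>

int main() {
  int lanes[4] = {1, 2, 3, 4}; // one value per lane, clusterSize = 4
  for (int offset = 1; offset < 4; offset <<= 1) {
    int next[4];
    for (int l = 0; l < 4; ++l)
      next[l] = lanes[l] + lanes[l ^ offset]; // gpu.shuffle xor, then add
    for (int l = 0; l < 4; ++l)
      lanes[l] = next[l];
  }
  // All four lanes now hold the total: 10 10 10 10.
  printf("%d %d %d %d\n", lanes[0], lanes[1], lanes[2], lanes[3]);
  return 0;
}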
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index d987b72..ff55f17 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -21,10 +21,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
intrinsics_gen
LINK_COMPONENTS
- AsmParser
BinaryFormat
- BitReader
- BitWriter
Core
LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b01596..422039f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -26,8 +26,7 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Type.h"
+#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
#include <numeric>
@@ -2707,7 +2706,7 @@ LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
while (alias) {
Block &initBlock = alias.getInitializerBlock();
auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
- auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp());
+ auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
// FIXME: This is a best effort solution. The AliasOp body might be more
// complex and in that case we bail out with success. To completely match
// the LLVM IR logic it would be necessary to implement proper alias and
@@ -4064,28 +4063,9 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
}
void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
- Value cond,
- ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) {
- SmallVector<ValueRange> opBundleOperands;
- SmallVector<Attribute> opBundleTags;
- opBundleOperands.reserve(opBundles.size());
- opBundleTags.reserve(opBundles.size());
-
- for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) {
- opBundleOperands.emplace_back(bundle.inputs());
- opBundleTags.push_back(
- StringAttr::get(builder.getContext(), bundle.getTag()));
- }
-
- auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
- return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
-}
-
-void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
Value cond, llvm::StringRef tag, ValueRange args) {
- llvm::OperandBundleDefT<Value> opBundle(
- tag.str(), SmallVector<Value>(args.begin(), args.end()));
- return build(builder, state, cond, opBundle);
+ return build(builder, state, cond, ArrayRef<ValueRange>(args),
+ builder.getStrArrayAttr(tag));
}
void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
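The AddressOfOp change above folds a getDefiningOp()/dyn_cast pair into the templated Value::getDefiningOp<OpTy>(), which returns null both when the value has no defining op at all (a block argument, where the old spelling would hand null to dyn_cast and assert) and when the defining op has a different type. A sketch:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

static void inspect(mlir::Value v) {
  // One null-safe call instead of getDefiningOp() followed by dyn_cast.
  if (auto addrOp = v.getDefiningOp<mlir::LLVM::AddressOfOp>()) {
    // v is produced by an llvm.mlir.addressof op; use addrOp here.
    (void)addrOp;
  }
}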
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6e29b12..52cd0ce 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -30,15 +30,9 @@
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/AsmParser/Parser.h"
-#include "llvm/IR/Attributes.h"
-#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/IntrinsicsNVPTX.h"
-#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 1a9ccf5..17371ec 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -24,7 +24,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/IR/Type.h"
using namespace mlir;
using namespace ROCDL;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp
index bd9d3528..1d4a0af 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp
@@ -20,11 +20,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/AsmParser/Parser.h"
-#include "llvm/IR/Attributes.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Type.h"
-#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace vcix;
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 935aa3c..b951df8 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -22,6 +22,8 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
#define DEBUG_TYPE "llvm-inliner"
using namespace mlir;
@@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
bool wouldBeCloned) const final {
auto callOp = dyn_cast<LLVM::CallOp>(call);
if (!callOp) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
- << LLVM::CallOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: call is not an '"
+ << LLVM::CallOp::getOperationName() << "' op";
return false;
}
if (callOp.getNoInline()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n");
+ LDBG() << "Cannot inline: call is marked no_inline";
return false;
}
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
if (!funcOp) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: callable is not an '"
- << LLVM::LLVMFuncOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: callable is not an '"
+ << LLVM::LLVMFuncOp::getOperationName() << "' op";
return false;
}
if (funcOp.isNoInline()) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: function is marked no_inline\n");
+ LDBG() << "Cannot inline: function is marked no_inline";
return false;
}
if (funcOp.isVarArg()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
+ LDBG() << "Cannot inline: callable is variadic";
return false;
}
// TODO: Generate aliasing metadata from noalias result attributes.
if (auto attrs = funcOp.getArgAttrs()) {
for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": inalloca arguments not supported\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": inalloca arguments not supported";
return false;
}
}
}
// TODO: Handle exceptions.
if (funcOp.getPersonality()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": unhandled function personality\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": unhandled function personality";
return false;
}
if (funcOp.getPassthrough()) {
@@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
if (!stringAttr)
return false;
if (disallowedFunctionAttrs.contains(stringAttr)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline " << funcOp.getSymName()
- << ": found disallowed function attribute "
- << stringAttr << "\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": found disallowed function attribute " << stringAttr;
return true;
}
return false;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index b6e168e..7f6ecab 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -15,7 +15,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
@@ -119,8 +119,8 @@ void mlir::linalg::LinalgDialect::initialize() {
addInterfaces<LinalgInlinerInterface>();
- declarePromisedInterface<mesh::ShardingInterface, GenericOp>();
- declarePromisedInterfaces<mesh::ShardingInterface,
+ declarePromisedInterface<shard::ShardingInterface, GenericOp>();
+ declarePromisedInterfaces<shard::ShardingInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1..73ae029 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.m.begin(), dimensions.m.end());
- llvm::sort(dimensions.n.begin(), dimensions.n.end());
- llvm::sort(dimensions.k.begin(), dimensions.k.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.m);
+ llvm::sort(dimensions.n);
+ llvm::sort(dimensions.k);
return dimensions;
}
@@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
/*dilations=*/SmallVector<int64_t, 2>{}};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
- llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
- llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
- llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
- llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.outputImage);
+ llvm::sort(dimensions.outputChannel);
+ llvm::sort(dimensions.filterLoop);
+ llvm::sort(dimensions.inputChannel);
+ llvm::sort(dimensions.depth);
// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
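The sort cleanups use llvm::sort's range overload, which is equivalent to the iterator-pair form. A sketch:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static void sortDims(llvm::SmallVectorImpl<unsigned> &dims) {
  // Identical behavior to llvm::sort(dims.begin(), dims.end()), without
  // repeating the begin()/end() boilerplate at every call site.
  llvm::sort(dims);
}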
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3..34c63d3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -62,10 +63,10 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
return getAsOpFoldResult(
TypeSwitch<Type, Value>(v.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
- return builder.create<tensor::DimOp>(loc, v, dim);
+ return tensor::DimOp::create(builder, loc, v, dim);
})
.Case<MemRefType>([&](MemRefType t) -> Value {
- return builder.create<memref::DimOp>(loc, v, dim);
+ return memref::DimOp::create(builder, loc, v, dim);
}));
}
@@ -77,12 +78,12 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> strides) {
return TypeSwitch<Type, Operation *>(source.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
- return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
- strides);
+ return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
+ strides);
})
.Case<MemRefType>([&](MemRefType type) -> Operation * {
- return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
- strides);
+ return memref::SubViewOp::create(b, loc, source, offsets, sizes,
+ strides);
})
.Default([&](Type t) -> Operation * { return nullptr; });
}
@@ -453,35 +454,35 @@ public:
builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
case UnaryFn::exp:
- return builder.create<math::ExpOp>(arg.getLoc(), arg);
+ return math::ExpOp::create(builder, arg.getLoc(), arg);
case UnaryFn::log:
- return builder.create<math::LogOp>(arg.getLoc(), arg);
+ return math::LogOp::create(builder, arg.getLoc(), arg);
case UnaryFn::abs:
- return builder.create<math::AbsFOp>(arg.getLoc(), arg);
+ return math::AbsFOp::create(builder, arg.getLoc(), arg);
case UnaryFn::ceil:
- return builder.create<math::CeilOp>(arg.getLoc(), arg);
+ return math::CeilOp::create(builder, arg.getLoc(), arg);
case UnaryFn::floor:
- return builder.create<math::FloorOp>(arg.getLoc(), arg);
+ return math::FloorOp::create(builder, arg.getLoc(), arg);
case UnaryFn::negf:
- return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+ return arith::NegFOp::create(builder, arg.getLoc(), arg);
case UnaryFn::reciprocal: {
Attribute oneAttr = builder.getOneAttr(arg.getType());
- auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
- ::cast<TypedAttr>(oneAttr));
- return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
+ auto one = arith::ConstantOp::create(builder, arg.getLoc(),
+ ::cast<TypedAttr>(oneAttr));
+ return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
}
case UnaryFn::round:
- return builder.create<math::RoundOp>(arg.getLoc(), arg);
+ return math::RoundOp::create(builder, arg.getLoc(), arg);
case UnaryFn::sqrt:
- return builder.create<math::SqrtOp>(arg.getLoc(), arg);
+ return math::SqrtOp::create(builder, arg.getLoc(), arg);
case UnaryFn::rsqrt:
- return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
+ return math::RsqrtOp::create(builder, arg.getLoc(), arg);
case UnaryFn::square:
- return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
+ return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
case UnaryFn::tanh:
- return builder.create<math::TanhOp>(arg.getLoc(), arg);
+ return math::TanhOp::create(builder, arg.getLoc(), arg);
case UnaryFn::erf:
- return builder.create<math::ErfOp>(arg.getLoc(), arg);
+ return math::ErfOp::create(builder, arg.getLoc(), arg);
}
if (emitError) {
emitError() << "unsupported unary function";
@@ -516,17 +517,17 @@ public:
switch (binaryFn) {
case BinaryFn::add:
if (allComplex)
- return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
+ return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
- return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+ return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool)
- return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
if (allComplex)
- return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
+ return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
- return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+ return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool) {
if (emitError) {
emitError() << "unsupported operation: sub with bools";
@@ -534,20 +535,20 @@ public:
}
llvm_unreachable("unsupported operation: sub with bools");
}
- return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allComplex)
- return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
+ return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
- return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+ return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool)
- return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::div:
if (allComplex)
- return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
+ return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
- return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
+ return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
if (allBool) {
if (emitError) {
emitError() << "unsupported operation: div with bools";
@@ -555,7 +556,7 @@ public:
}
llvm_unreachable("unsupported operation: div with bools");
}
- return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::div_unsigned:
if (!allInteger || allBool) {
if (emitError) {
@@ -564,30 +565,30 @@ public:
}
llvm_unreachable("unsupported operation: unsigned div not on uint");
}
- return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
assert(!allComplex);
if (allFloatingPoint)
- return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_signed:
assert(!allComplex);
if (allFloatingPoint)
- return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
if (allFloatingPoint)
- return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
if (allFloatingPoint)
- return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
+ return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
- return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
+ return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
}
if (emitError) {
emitError() << "unsupported binary function";
@@ -610,7 +611,7 @@ public:
case TernaryFn::select:
if (!headBool && !(tailFloatingPoint || tailInteger))
llvm_unreachable("unsupported non numeric type");
- return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
+ return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
}
if (emitError) {
emitError() << "unsupported ternary function";
@@ -639,7 +640,7 @@ public:
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
- builder.create<YieldOp>(loc, values);
+ YieldOp::create(builder, loc, values);
}
Value constant(const std::string &value) {
@@ -647,13 +648,14 @@ public:
builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
- return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
+ return arith::ConstantOp::create(builder, loc,
+ ::cast<TypedAttr>(valueAttr));
}
Value index(int64_t dim) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
- return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+ return IndexOp::create(builder, builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
@@ -749,14 +751,14 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
TensorReshapeOp newInit;
if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
- newInit = rewriter.create<TensorReshapeOp>(
- loc, reshapeOp.getResultType(), oldFill.output(),
+ newInit = TensorReshapeOp::create(
+ rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
reshapeOp.getStaticOutputShape());
} else {
- newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
- oldFill.output(),
- reshapeOp.getReassociation());
+ newInit = TensorReshapeOp::create(
+ rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
+ reshapeOp.getReassociation());
}
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
ValueRange{newInit});
@@ -786,17 +788,16 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
return rewriter.notifyMatchFailure(
padOp, "failed to reify tensor.pad op result shape");
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- padOp.getLoc(), reifiedShape.front(),
- padOp.getResultType().getElementType());
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
+ padOp.getResultType().getElementType());
Value replacement =
- rewriter
- .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
- ValueRange{emptyTensor})
+ FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
+ ValueRange{emptyTensor})
.getResult(0);
if (replacement.getType() != padOp.getResultType()) {
- replacement = rewriter.create<tensor::CastOp>(
- fillOp.getLoc(), padOp.getResultType(), replacement);
+ replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
+ padOp.getResultType(), replacement);
}
rewriter.replaceOp(padOp, replacement);
return success();
@@ -889,7 +890,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
newSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
+ tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
.getResult());
} else {
newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
@@ -942,8 +943,8 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
if (!packOpDest.hasOneUse())
return failure();
- return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
- packOp.getDest());
+ return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
+ packOp.getDest());
}
/// Wrapper pattern that applies foldFillPackIntoFillOp method.
@@ -1042,8 +1043,8 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
concatOp, "not all operands are defined by a compatible fill op");
}
- Value outsConcat = rewriter.create<tensor::ConcatOp>(
- concatOp.getLoc(), concatOp.getDim(), allOuts);
+ Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
+ concatOp.getDim(), allOuts);
rewriter.replaceOpWithNewOp<linalg::FillOp>(
concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
return success();
@@ -1407,14 +1408,14 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
// TODO: unify the two ops?
if (sparse_tensor::getSparseTensorEncoding(returnType) ||
sparse_tensor::getSparseTensorEncoding(resultType))
- returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
- linalgOp.getLoc(), resultType, returnedArg);
+ returnedArg = sparse_tensor::ConvertOp::create(
+ rewriter, linalgOp.getLoc(), resultType, returnedArg);
else {
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
resultType))
return failure();
- returnedArg = rewriter.create<tensor::CastOp>(
- linalgOp.getLoc(), resultType, returnedArg);
+ returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
+ resultType, returnedArg);
}
}
returnedArgs.push_back(returnedArg);
@@ -1528,7 +1529,7 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
.getElementType()},
payloadOpAttrs);
- b.create<YieldOp>(result.location, payloadOp->getResults());
+ YieldOp::create(b, result.location, payloadOp->getResults());
}
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1945,7 +1946,7 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc,
buildGenericRegion(builder, loc, region, inputs, outputs,
[](OpBuilder &b, Location loc, ValueRange args) {
if (!args.empty())
- b.create<linalg::YieldOp>(loc, args[0]);
+ linalg::YieldOp::create(b, loc, args[0]);
});
}
@@ -2138,7 +2139,7 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
unsigned inputRank = broadcastInputTy.getRank();
for (unsigned i = 0; i < inputRank; ++i) {
if (broadcastInputTy.isDynamicDim(i)) {
- dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
+ dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
->getResult(0));
} else {
dims.push_back(IntegerAttr::get(IndexType::get(ctx),
@@ -2147,15 +2148,14 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
}
SmallVector<OpFoldResult> transposeResultShapes =
applyPermutation(dims, resultPerms);
- Value transposeInit = rewriter.create<tensor::EmptyOp>(
- transposeOp.getLoc(), transposeResultShapes,
+ Value transposeInit = tensor::EmptyOp::create(
+ rewriter, transposeOp.getLoc(), transposeResultShapes,
broadcastInputTy.getElementType());
// Create broadcast(transpose(input)).
Value transposeResult =
- rewriter
- .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
- resultPerms)
+ TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
+ transposeInit, resultPerms)
->getResult(0);
rewriter.replaceOpWithNewOp<BroadcastOp>(
transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
@@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+ using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+ if (!defBroadcastOp)
+ return failure();
+ ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ SmallVector<int64_t> foldedDims(dimensions);
+ Value init = broadcastOp.getInit();
+ int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+ // Mapping from input dims to init dims.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensions, dim))
+ dimMap.push_back(dim);
+ }
+ for (auto dim : defDimensions)
+ foldedDims.push_back(dimMap[dim]);
+
+ llvm::sort(foldedDims);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+ return success();
+ }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+ results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
}
//===----------------------------------------------------------------------===//
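The dimension arithmetic in FoldBroadcasts can be followed on a concrete case. A self-contained sketch with hypothetical ranks (this mirrors the index bookkeeping only, not the MLIR pattern itself): a value broadcast along dim 1 to rank 3, then along dim 0 to rank 4, folds into one broadcast along dims {0, 2}:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> defDims = {1}; // dims added by the producer broadcast
  std::vector<long> dims = {0};    // dims added by the consumer broadcast
  long initRank = 4;               // rank of the consumer's init
  std::vector<long> folded(dims);
  std::vector<long> dimMap; // consumer input dim -> consumer init dim
  for (long d = 0; d < initRank; ++d)
    if (std::find(dims.begin(), dims.end(), d) == dims.end())
      dimMap.push_back(d); // {1, 2, 3}
  for (long d : defDims)
    folded.push_back(dimMap[d]); // dimMap[1] == 2
  std::sort(folded.begin(), folded.end());
  for (long d : folded)
    printf("%ld ", d); // prints: 0 2
  printf("\n");
  return 0;
}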
@@ -2547,7 +2577,7 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
// continue to propagate as far up the stack as it can go.
OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
Value newOperand =
- rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
+ tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
SmallVector<Value> newOperands = linalgOp.getDpsInputs();
SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
linalgOp.getDpsInits().end());
@@ -2560,8 +2590,8 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
// Create a tensor.cast operation back to the original type.
- Value castBack = rewriter.create<tensor::CastOp>(
- loc, resultValue.getType(), newOp->getResult(resultNumber));
+ Value castBack = tensor::CastOp::create(
+ rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
results[resultNumber] = castBack;
@@ -2653,7 +2683,7 @@ static void createNewOperandWithStaticSizes(
changeNeeded = true;
// Get the new operand value given its size and element type by
// casting it.
- Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
+ Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
unsigned index = opOperand->getOperandNumber();
newOperands[index] = newOperand;
}
@@ -2718,7 +2748,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
Type oldType = oldResult.getType();
replacements.push_back(
(newType != oldType)
- ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
+ ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
: newResult);
}
rewriter.replaceOp(linalgOp, replacements);
@@ -2756,8 +2786,8 @@ SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
int64_t operandRank = getInputOperandRank();
SmallVector<Range> loopBounds(operandRank);
Location loc = getLoc();
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value one = arith::ConstantIndexOp::create(builder, loc, 1);
Value source = getInput();
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
loopBounds[dim].offset = zero;
@@ -2924,11 +2954,11 @@ static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
"We should have two maps: 1 for the input, 1 for the output");
assert(indexingMaps[0].isIdentity() && "input map should be identity");
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), input, output, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<T>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
+ auto genericOp = linalg::GenericOp::create(
+ builder, loc, output.getType(), input, output, indexingMaps,
+ iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value result = T::create(b, loc, args[0], args[1]);
+ linalg::YieldOp::create(b, loc, result);
});
return genericOp.getResult(0);
}
@@ -2947,12 +2977,13 @@ static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
assert(indexingMaps[0].isIdentity() && "input map should be identity");
// Add the affine map for the output argument.
indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
- Value result = b.create<math::ExpOp>(loc, diff);
- b.create<linalg::YieldOp>(loc, result);
+ auto genericOp = linalg::GenericOp::create(
+ builder, loc, input.getType(), ValueRange{input, max}, output,
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
+ Value result = math::ExpOp::create(b, loc, diff);
+ linalg::YieldOp::create(b, loc, result);
});
return genericOp.getResult(0);
}
@@ -2974,12 +3005,12 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
// Add the affine map for the output tensor.
indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, numerator.getType(), ValueRange{numerator, denominator}, output,
- indexingMaps, iteratorTypes,
+ auto genericOp = linalg::GenericOp::create(
+ builder, loc, numerator.getType(), ValueRange{numerator, denominator},
+ output, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
+ Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
+ linalg::YieldOp::create(b, loc, result);
});
return genericOp.getResult(0);
}
@@ -3015,12 +3046,12 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
Value output = getOutput();
dims.erase(dims.begin() + reductionDim);
// Step 1: Compute max along dim.
- Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
+ Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
elementType, b, loc,
/*useOnlyFiniteValue=*/true);
Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
+ linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
.result();
Value max =
reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
@@ -3032,7 +3063,7 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
b, loc, /*useOnlyFiniteValue=*/true);
Value zeroInit =
- b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
+ linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
Value denominator =
reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
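For reference, these steps decompose the numerically stable softmax along the reduction dimension; this is the textbook formulation, not something specific to this change:

\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j

Subtracting the maximum m leaves the result unchanged while keeping every exponential in (0, 1], which is why the max reduction (Step 1) precedes the subtract-and-exp numerator, the sum for the denominator, and the final division.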
@@ -3153,8 +3184,8 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t filterRank = getFilterOperandRank();
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
Location loc = getLoc();
- auto filterSlice = builder.create<tensor::ExtractSliceOp>(
- loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
+ auto filterSlice = tensor::ExtractSliceOp::create(
+ builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
tiledOperands.emplace_back(filterSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
@@ -3164,8 +3195,8 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- auto outputSlice = builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ auto outputSlice = tensor::ExtractSliceOp::create(
+ builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
@@ -3333,8 +3364,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
- auto inputSlice = builder.create<tensor::ExtractSliceOp>(
- loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
+ auto inputSlice = tensor::ExtractSliceOp::create(
+ builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
tiledOperands.emplace_back(inputSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
@@ -3344,8 +3375,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
- auto outputSlice = builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, outputStrides);
+ auto outputSlice = tensor::ExtractSliceOp::create(
+ builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
@@ -3504,8 +3535,8 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
- auto valueSlice = builder.create<tensor::ExtractSliceOp>(
- loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
+ auto valueSlice = tensor::ExtractSliceOp::create(
+ builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
tiledOperands.emplace_back(valueSlice);
SmallVector<OpFoldResult> resultOffsets, resultSizes;
@@ -3515,8 +3546,8 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
- auto outputSlice = builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), resultOffsets, resultSizes, strides);
+ auto outputSlice = tensor::ExtractSliceOp::create(
+ builder, loc, getOutput(), resultOffsets, resultSizes, strides);
tiledOperands.emplace_back(outputSlice);
SmallVector<Type> resultTypes;
@@ -4490,6 +4521,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
+
+template <typename OpTy, typename>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+ RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
+ SmallVector<int64_t> result(
+ packedType.getShape().take_front(unpackedType.getRank()));
+ if (!packOrUnPack.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+ }
+ return result;
+}
+template SmallVector<int64_t>
+ getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+ getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
@@ -4599,22 +4653,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4673,11 +4711,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4694,6 +4727,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("expected ")
+ << expectedPackedType << " for the packed domain value, got "
+ << packedType;
+ }
return success();
}
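The verifier above now delegates to verifyCompatibleShape, which accepts shapes of equal rank whose dimensions agree wherever both sides are static. A standalone sketch of the per-dimension rule, using -1 as a stand-in for the dynamic-size sentinel:

#include <cstdio>

static bool compatibleDim(long a, long b) {
  return a == -1 || b == -1 || a == b; // dynamic matches anything
}

int main() {
  printf("%d\n", compatibleDim(4, 4));   // 1: equal statics
  printf("%d\n", compatibleDim(-1, 16)); // 1: dynamic vs static
  printf("%d\n", compatibleDim(8, 16));  // 0: mismatch -> verifier error
  return 0;
}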
@@ -4971,7 +5010,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
llvm::cast<RankedTensorType>(source.getType()).getShape())) {
if (ShapedType::isDynamic(value))
mixedSizes.push_back(
- b.create<tensor::DimOp>(loc, source, index).getResult());
+ tensor::DimOp::create(b, loc, source, index).getResult());
else
mixedSizes.push_back(b.getIndexAttr(value));
}
@@ -4985,7 +5024,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+ return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
@@ -4996,9 +5035,9 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
Value transposedDest =
createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
metadata.innerDimsPos, metadata.outerDimsPerm);
- return b.create<PackOp>(loc, getSource(), transposedDest,
- metadata.innerDimsPos, metadata.innerTiles,
- getPaddingValue(), metadata.outerDimsPerm);
+ return PackOp::create(b, loc, getSource(), transposedDest,
+ metadata.innerDimsPos, metadata.innerTiles,
+ getPaddingValue(), metadata.outerDimsPerm);
}
/// Returns true if the tiles and the tiled dims are constant.
@@ -5138,7 +5177,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
if (srcShape != packOp.getSourceType().getShape()) {
auto newSrcType = packOp.getSourceType().clone(srcShape);
source =
- rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
RankedTensorType originalResultType = packOp.getDestType();
@@ -5146,7 +5185,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
dest =
- rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
}
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
@@ -5157,7 +5196,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
if (needUpdateDestType) {
rewriter.setInsertionPointAfter(packOp);
auto castOp =
- rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+ tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
}
return success();
@@ -5250,18 +5289,20 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
- PackOp newOp = rewriter.create<PackOp>(
- op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
- newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+ PackOp newOp =
+ PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
+ op.getInnerDimsPos(), newMixedTileSizes,
+ op.getPaddingValue(), op.getOuterDimsPerm());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
// Replace op.
Value oldResult = op.getResult();
Value newResult = newOp.getResult();
- Value replacement = (newResult.getType() != oldResult.getType())
- ? rewriter.create<tensor::CastOp>(
- op->getLoc(), oldResult.getType(), newResult)
- : newResult;
+ Value replacement =
+ (newResult.getType() != oldResult.getType())
+ ? tensor::CastOp::create(rewriter, op->getLoc(),
+ oldResult.getType(), newResult)
+ : newResult;
rewriter.replaceOp(op, {replacement});
@@ -5358,7 +5399,8 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
for (auto i :
llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
if (srcType.isDynamicDim(i))
- mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
+ mixedSizes.push_back(
+ tensor::DimOp::create(b, loc, source, i).getResult());
else
mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
}
@@ -5371,7 +5413,7 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
auto elemType = srcType.getElementType();
- return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+ return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
@@ -5380,9 +5422,9 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
ArrayRef<int64_t> outerPermutation) {
PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
*this, innerPermutation, outerPermutation);
- return b.create<UnPackOp>(loc, transposedSource, getDest(),
- metadata.innerDimsPos, metadata.innerTiles,
- metadata.outerDimsPerm);
+ return UnPackOp::create(b, loc, transposedSource, getDest(),
+ metadata.innerDimsPos, metadata.innerTiles,
+ metadata.outerDimsPerm);
}
/// Returns true if the `srcShape` or `destShape` is different from the one in
@@ -5447,15 +5489,11 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
if (unPackOp->hasOneUse()) {
auto extractSliceUser =
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
- if (extractSliceUser &&
- areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
- areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
- extractSliceUser.getSourceType().getRank() ==
- extractSliceUser.getResultType().getRank()) {
+ if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- auto newDest = rewriter.create<tensor::ExtractSliceOp>(
- unPackOp->getLoc(), unPackOp.getDest(),
+ auto newDest = tensor::ExtractSliceOp::create(
+ rewriter, unPackOp->getLoc(), unPackOp.getDest(),
extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
extractSliceUser.getMixedStrides());
rewriter.modifyOpInPlace(unPackOp, [&]() {
@@ -5474,18 +5512,18 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
Value source = unPackOp.getSource();
if (srcShape != unPackOp.getSourceType().getShape()) {
auto newSrcType = unPackOp.getSourceType().clone(srcShape);
- source = rewriter.create<tensor::CastOp>(loc, newSrcType,
- unPackOp.getSource());
+ source = tensor::CastOp::create(rewriter, loc, newSrcType,
+ unPackOp.getSource());
}
Value dest = unPackOp.getDest();
if (destShape != unPackOp.getDestType().getShape()) {
auto newDestType = unPackOp.getDestType().clone(destShape);
- dest =
- rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+ dest = tensor::CastOp::create(rewriter, loc, newDestType,
+ unPackOp.getDest());
}
- Value newOp = rewriter.create<UnPackOp>(
- loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
- unPackOp.getOuterDimsPerm());
+ Value newOp = UnPackOp::create(
+ rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
+ unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
unPackOp, unPackOp.getResult().getType(), newOp);
return success();
@@ -5494,6 +5532,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
}
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+ // Rank-reduced folding is not supported.
+ if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+ return false;
+ if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+ !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+ return false;
+ RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(*this);
+ for (auto [pos, tileSize] :
+ llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+ if (unpackedTypeAfterFold.isDynamicDim(pos))
+ return false;
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return false;
+ if (ShapedType::isDynamic(tileSize))
+ return false;
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedTypeAfterFold.getDimSize(pos);
+ if (paddingSize >= tileSize)
+ return false;
+ }
+ return true;
+}
+
bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
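The padding bound in canFoldSliceOp admits only slices that trim at most the padding introduced in the final tile of each packed dimension. A numeric sketch with hypothetical sizes: four outer tiles of 16 elements unpack to 64, so slicing down to 50 drops 14 elements (less than one tile, foldable) while slicing to 40 drops 24 (a full tile and more, not foldable):

#include <cstdio>

int main() {
  long outer = 4, tile = 16; // outerShapeWithoutTranspose[pos], tileSize
  for (long sliced : {50L, 40L}) {
    long padding = outer * tile - sliced;
    printf("sliced=%ld padding=%ld foldable=%d\n", sliced, padding,
           (int)(padding < tile));
  }
  return 0;
}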
@@ -5542,18 +5606,19 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
- UnPackOp newOp = rewriter.create<UnPackOp>(
- op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
- newMixedTileSizes, op.getOuterDimsPerm());
+ UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
+ newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
// Replace op.
Value oldResult = op.getResult();
Value newResult = newOp.getResult();
- Value replacement = (newResult.getType() != oldResult.getType())
- ? rewriter.create<tensor::CastOp>(
- op->getLoc(), oldResult.getType(), newResult)
- : newResult;
+ Value replacement =
+ (newResult.getType() != oldResult.getType())
+ ? tensor::CastOp::create(rewriter, op->getLoc(),
+ oldResult.getType(), newResult)
+ : newResult;
rewriter.replaceOp(op, {replacement});
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
index ce1b1b9..5c8c2de 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
@@ -12,6 +12,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -21,8 +22,6 @@
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
static Attribute linearId0(MLIRContext *ctx) {
return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0);
@@ -43,9 +42,8 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
assert(!copySizes.empty() && copySizes.size() <= 3 &&
"only 1,2,3-D copies are supported for now");
- LDBG("START CopyMappingInfo, favorPredication: " << favorPredication);
- LLVM_DEBUG(DBGS() << "--copy shape: " << llvm::interleaved(copySizes)
- << "\n");
+ LDBG() << "START CopyMappingInfo, favorPredication: " << favorPredication;
+ LDBG() << "--copy shape: " << llvm::interleaved(copySizes);
// Greedily find the largest vector size that can be used to copy the most
// minor dimension: we are in the business of filling kMaxVectorLoadBitWidth
@@ -53,20 +51,19 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer(
desiredBitAlignment, copySizes.back(), elementalBitwidth);
- LDBG("--greedily determined vectorSize: "
- << desiredVectorSize << " elements of " << elementalBitwidth
- << "b each -> " << (desiredVectorSize * elementalBitwidth)
- << "b total out of a max of " << kMaxVectorLoadBitWidth << "b");
+ LDBG() << "--greedily determined vectorSize: " << desiredVectorSize
+ << " elements of " << elementalBitwidth << "b each -> "
+ << (desiredVectorSize * elementalBitwidth)
+ << "b total out of a max of " << kMaxVectorLoadBitWidth << "b";
status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize,
favorPredication);
if (status == Status::Invalid)
return;
- LLVM_DEBUG(DBGS() << "--copy: " << llvm::interleaved(copySizes) << "\n"
- << "--numThreads: " << llvm::interleaved(this->numThreads)
- << "\n"
- << "--vectorSize: " << this->vectorSize << "\n");
+ LDBG() << "--copy: " << llvm::interleaved(copySizes) << "\n"
+ << "--numThreads: " << llvm::interleaved(this->numThreads) << "\n"
+ << "--vectorSize: " << this->vectorSize;
assert(this->numThreads.size() == copySizes.size() &&
"compute copy mapping expected same number of threads and copy sizes");
@@ -84,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
this->threadMapping =
llvm::to_vector(ArrayRef(allThreadMappings)
.take_back(this->smallestBoundingTileSizes.size()));
- LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n");
+ LDBG() << *this;
}
int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
@@ -140,7 +137,7 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes,
"currentIndex out of bounds");
std::string indent(2 * currentIndex, '-');
if (static_cast<size_t>(currentIndex) == sizes.size() - 1) {
- LDBG(indent << "mandated globalBest: " << sizes[currentIndex]);
+ LDBG() << indent << "mandated globalBest: " << sizes[currentIndex];
return SmallVector<int64_t>{sizes[currentIndex]};
}
@@ -149,16 +146,16 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes,
SmallVector<int64_t> factors = getFactors(s);
SmallVector<int64_t> localThreadsPerDim;
localThreadsPerDim.reserve(sizes.size());
- LDBG(indent << "maximizeNumThreads in " << s
- << " with limit: " << maxNumThreads);
+ LDBG() << indent << "maximizeNumThreads in " << s
+ << " with limit: " << maxNumThreads;
for (auto factor : factors) {
auto nestedThreadsPerDim =
maximizeNumThreads(sizes, currentIndex + 1, maxNumThreads / factor);
int64_t localBest = factor * product(nestedThreadsPerDim);
if (localBest > best && localBest <= maxNumThreads) {
- LDBG(indent << "new localBest: " << localBest);
- LDBG(indent << "nestedThreadsPerDim: "
- << llvm::interleaved(nestedThreadsPerDim));
+ LDBG() << indent << "new localBest: " << localBest;
+ LDBG() << indent << "nestedThreadsPerDim: "
+ << llvm::interleaved(nestedThreadsPerDim);
localThreadsPerDim.clear();
localThreadsPerDim.push_back(factor);
llvm::append_range(localThreadsPerDim, nestedThreadsPerDim);
@@ -166,8 +163,8 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes,
}
}
- LDBG(indent << "found globalBest: " << best);
- LDBG(indent << "numThreads: " << llvm::interleaved(localThreadsPerDim));
+ LDBG() << indent << "found globalBest: " << best;
+ LDBG() << indent << "numThreads: " << llvm::interleaved(localThreadsPerDim);
return localThreadsPerDim;
}
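For reference, a scalar paraphrase of the recursive search above (illustrative only: product mirrors the helper used in the hunk, the factor enumeration stands in for getFactors, and the parts of the function outside this diff are assumed):

#include <cstddef>
#include <cstdint>
#include <vector>

static std::int64_t product(const std::vector<std::int64_t> &v) {
  std::int64_t p = 1;
  for (std::int64_t x : v)
    p *= x;
  return p;
}

// Pick a factor of sizes[i] per dimension, recurse on the remaining budget,
// and keep the factorization with the largest product that still fits. The
// most-minor dimension is mandated, matching the base case above.
static std::vector<std::int64_t>
maximizeNumThreads(const std::vector<std::int64_t> &sizes, std::size_t i,
                   std::int64_t maxNumThreads) {
  if (i == sizes.size() - 1)
    return {sizes[i]};
  std::int64_t best = 0;
  std::vector<std::int64_t> result;
  for (std::int64_t f = 1; f <= sizes[i]; ++f) {
    if (sizes[i] % f != 0)
      continue; // enumerate factors of sizes[i] only.
    auto nested = maximizeNumThreads(sizes, i + 1, maxNumThreads / f);
    std::int64_t local = f * product(nested);
    if (local > best && local <= maxNumThreads) {
      best = local;
      result.assign(1, f);
      result.insert(result.end(), nested.begin(), nested.end());
    }
  }
  return result;
}
// E.g. sizes {4, 6} with a budget of 16 threads yields {2, 6}: the full
// 4 * 6 = 24 exceeds the budget, so the search settles for 2 * 6 = 12.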
@@ -192,8 +189,8 @@ transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads,
if (status == Status::Success || status == Status::Invalid)
return status;
- LDBG("requires predication, try reducing vector size to "
- << (localVectorSize / 2));
+ LDBG() << "requires predication, try reducing vector size to "
+ << (localVectorSize / 2);
}
}
@@ -210,8 +207,8 @@ transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
assert(sizes.back() % desiredVectorSize == 0 &&
"most-minor size not divisible by actualVectorSize");
- LDBG("inferNumThreadsImpl with totalNumThreads: "
- << totalNumThreads << " and vectorSize: " << desiredVectorSize);
+ LDBG() << "inferNumThreadsImpl with totalNumThreads: " << totalNumThreads
+ << " and vectorSize: " << desiredVectorSize;
// Scale the most minor size to account for the chosen vector size and
// maximize the number of threads without exceeding the total number of
@@ -219,22 +216,22 @@ transform::gpu::CopyMappingInfo::inferNumThreadsImpl(
SmallVector<int64_t> scaledSizes(sizes);
scaledSizes.back() /= desiredVectorSize;
if (scaledSizes.back() > totalNumThreads) {
- LDBG("--Too few threads given the required vector size -> FAIL");
+ LDBG() << "--Too few threads given the required vector size -> FAIL";
return Status::Invalid;
}
SmallVector<int64_t> inferredNumThreads =
maximizeNumThreads(scaledSizes, 0, totalNumThreads);
- LDBG("inferred numThreads: " << llvm::interleaved(inferredNumThreads));
- LDBG("computed actualVectorSize: " << desiredVectorSize);
+ LDBG() << "inferred numThreads: " << llvm::interleaved(inferredNumThreads);
+ LDBG() << "computed actualVectorSize: " << desiredVectorSize;
// Corner case: we cannot use more threads than available. If the dimension of
// the copy is so bad it is because higher-level tiling did not do its job, we
// do not try to recover from it here.
int64_t totalNumThreadsUsed = product(inferredNumThreads);
- LDBG("--totalNumThreadsUsed: " << totalNumThreadsUsed);
+ LDBG() << "--totalNumThreadsUsed: " << totalNumThreadsUsed;
if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) {
- LDBG("--Too few threads given the required vector size -> FAIL");
+ LDBG() << "--Too few threads given the required vector size -> FAIL";
return Status::Invalid;
}
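Note: the hunks in this file swap the locally defined DBGS()/LDBG(X) macros (deleted at the top) for the streaming LDBG() from llvm/Support/DebugLog.h. A minimal before/after sketch, assuming only what the deleted macros show; vectorSize is an illustrative variable:

// Before: per-file macros layered over llvm/Support/Debug.h.
#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
LDBG("--vectorSize: " << vectorSize);

// After: DebugLog.h's LDBG() yields a stream that carries the DEBUG_TYPE
// prefix and terminates the line itself, so call sites drop both the
// wrapper macro and the trailing "\n".
#include "llvm/Support/DebugLog.h"
LDBG() << "--vectorSize: " << vectorSize;

The LDBG() << *this form at the end of the constructor additionally assumes CopyMappingInfo is streamable (an operator<< overload or equivalent); the old code called this->print(DBGS()) explicitly.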
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 2fe72a3..d4a3e5f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -15,14 +15,13 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
//===----------------------------------------------------------------------===//
// StructuredMatchOp
@@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
return emitSilenceableError() << "expected a Linalg op";
}
// If errors are suppressed, succeed and set all results to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
+ LDBG() << "optional nested matcher expected a Linalg op";
results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
return DiagnosedSilenceableFailure::success();
}
@@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
// When they are defined in this block, we additionally check if we have
// already applied the operation that defines them. If not, the
// corresponding results will be set to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
- << "\n");
+ LDBG() << "optional nested matcher failed: " << diag.getMessage();
(void)diag.silence();
SmallVector<OpOperand *> undefinedOperands;
for (OpOperand &terminatorOperand :
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 109e5b7..bdfc8d0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -40,7 +40,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include <type_traits>
@@ -49,9 +49,6 @@ using namespace mlir::linalg;
using namespace mlir::transform;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
@@ -672,9 +669,10 @@ static Operation *replaceForAllWithNewSignature(
newOuts.push_back(outputs[resultNumber]);
// Create new scf.forall op
- auto newforallOp = rewriter.create<scf::ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ auto newforallOp = scf::ForallOp::create(
+ rewriter, loc, forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+ forallOp.getMapping());
rewriter.eraseBlock(newforallOp.getBody());
newforallOp.getRegion().takeBody(forallOp.getRegion());
@@ -699,8 +697,8 @@ static Operation *replaceForAllWithNewSignature(
Value src = tileAndFuseResult.tiledValues[0];
Value dst = newforallOp.getRegionIterArgs().back();
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
- dst, offsets, sizes, strides);
+ tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
+ dst, offsets, sizes, strides);
for (auto result : llvm::enumerate(forallOp.getResults())) {
rewriter.replaceAllUsesWith(result.value(),
@@ -772,7 +770,7 @@ static bool sameOrEquivalentIterArg(Value src, Value dst) {
static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp, Operation *containingOp) {
- LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
+ LDBG() << "Try to fuse a direct extract use";
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
@@ -837,7 +835,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Tile the producer.
int64_t resultNumber =
cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
- LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
+ LDBG() << "resultNumber: " << resultNumber;
SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
@@ -854,7 +852,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
#ifndef NDEBUG
for (auto *tiledOp : tileAndFuseResult->tiledOps) {
- LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
+ LDBG() << "tiledProducer: " << *tiledOp;
}
#endif
@@ -893,7 +891,7 @@ static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
- LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
+ LDBG() << "Try to fuse an extract use through block argument";
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
@@ -946,7 +944,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the use in the tileableProducer before tiling: clone, replace and
// then tile.
int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
- LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
+ LDBG() << "resultNumber: " << resultNumber;
// Gather destination tensors.
SmallVector<Value> destinationTensors;
@@ -995,7 +993,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
- LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
+ LDBG() << "Try to fuse an use by cloning";
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
@@ -1029,7 +1027,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
- LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
+ LDBG() << "resultNumber: " << resultNumber;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
@@ -1112,7 +1110,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
auto [tiledOps, newContainingOp] =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (!tiledOps.empty()) {
- LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
+ LDBG() << "\nFused a direct extract use\n" << *containingOp;
fusedOps.append(tiledOps);
if (newContainingOp) {
// Update handles associated with the containing op so we don't need to
@@ -1138,8 +1136,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter, diag, producerOp, containingOp);
if (!tiledContainingOpOperand.empty()) {
- LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
- << *containingOp);
+ LDBG() << "\nFused an extract use through block argument\n"
+ << *containingOp;
fusedOps.append(tiledContainingOpOperand);
continue;
}
@@ -1147,7 +1145,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
Operation *cloned =
cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
if (cloned) {
- LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
+ LDBG() << "\nFused an use by cloning\n" << *containingOp;
fusedOps.push_back(cloned);
continue;
}
@@ -1851,7 +1849,7 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
OpOperand *packUse = linalgOp.getDpsInitOperand(
cast<OpResult>(unPackOp.getSource()).getResultNumber());
- packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
+ packOp = packUse->get().getDefiningOp<linalg::PackOp>();
if (!packOp || !packOp.getResult().hasOneUse())
return emitSilenceableError() << "could not find matching pack op";
}
@@ -3410,12 +3408,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
if (scalableSizes[ofrIdx]) {
- auto val = b.create<arith::ConstantIndexOp>(
- getLoc(), cast<IntegerAttr>(attr).getInt());
+ auto val = arith::ConstantIndexOp::create(
+ b, getLoc(), cast<IntegerAttr>(attr).getInt());
Value vscale =
- b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
+ vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
sizes.push_back(
- b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
+ arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
} else {
sizes.push_back(attr);
}
@@ -3626,9 +3624,10 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
rewriter.getIndexAttr(1));
- auto normalizedForallOp = rewriter.create<scf::ForallOp>(
- loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
- loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+ auto normalizedForallOp = scf::ForallOp::create(
+ rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
+ loop.getOutputs(), loop.getMapping(),
+ [](OpBuilder &, Location, ValueRange) {});
auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
OpBuilder::InsertionGuard g(rewriter);
@@ -4131,12 +4130,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
target->template getParentOfType<scf::InParallelOp>());
}
- Value extracted = rewriter.create<tensor::ExtractSliceOp>(
- target.getLoc(), target.getDest(), target.getMixedOffsets(),
+ Value extracted = tensor::ExtractSliceOp::create(
+ rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
target.getMixedSizes(), target.getMixedStrides());
- Value copied = rewriter
- .create<linalg::CopyOp>(target.getLoc(),
- target.getSource(), extracted)
+ Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
+ target.getSource(), extracted)
.getResult(0);
// Reset the insertion point.
rewriter.setInsertionPoint(target);
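Alongside the debug-logging change, the hunks in this file mechanically migrate builder calls from the templated member function to the static create hook on each op class; both spellings construct the same operation. A minimal sketch (the constant value is illustrative):

// Before: templated member on the builder.
auto cst = rewriter.create<arith::ConstantIndexOp>(loc, 42);
// After: static create on the op class; the builder moves to the first
// argument and everything else is unchanged.
auto cst = arith::ConstantIndexOp::create(rewriter, loc, 42);

The scalable-size hunk uses the same form to materialize a scalable tile size as the product of an index constant and vector.vscale.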
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
index 281d9f2..ba94ad7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -10,14 +10,14 @@
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
void mlir::linalg::registerAllDialectInterfaceImplementations(
DialectRegistry &registry) {
registerBufferizableOpInterfaceExternalModels(registry);
- registerMeshShardingInterfaceExternalModels(registry);
+ registerShardingInterfaceExternalModels(registry);
registerSubsetOpInterfaceExternalModels(registry);
registerTilingInterfaceExternalModels(registry);
registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 1f6d96c..3512ecd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -184,9 +184,9 @@ struct SoftmaxOpInterface
getBuffer(rewriter, softmaxOp.getOutput(), options, state);
if (failed(outputBuffer))
return failure();
- rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
- /*result=*/TypeRange(), *inputBuffer,
- *outputBuffer, softmaxOp.getDimension());
+ linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(),
+ /*result=*/TypeRange(), *inputBuffer,
+ *outputBuffer, softmaxOp.getDimension());
replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fda..70f846e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,7 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Interchange.cpp
Loops.cpp
TransposeMatmul.cpp
- MeshShardingInterfaceImpl.cpp
+ ShardingInterfaceImpl.cpp
NamedOpConversions.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
@@ -68,7 +68,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRIR
MLIRMemRefDialect
MLIRMemRefTransforms
- MLIRMeshTransforms
+ MLIRShardTransforms
MLIRLinalgDialect
MLIRLinalgUtils
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index a7732b9..d1eb270 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -30,10 +30,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
if (isa<IntegerType>(x.getType()))
- return builder.create<arith::AddIOp>(loc, x, y);
+ return arith::AddIOp::create(builder, loc, x, y);
if (isa<ComplexType>(x.getType()))
- return builder.create<complex::AddOp>(loc, x, y);
- return builder.create<arith::AddFOp>(loc, x, y);
+ return complex::AddOp::create(builder, loc, x, y);
+ return arith::AddFOp::create(builder, loc, x, y);
}
static Value createMul(Location loc, Value x, Value y, Type accType,
@@ -44,10 +44,10 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
Value yConvert =
convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
if (isa<ComplexType>(accType))
- return builder.create<complex::MulOp>(loc, xConvert, yConvert);
+ return complex::MulOp::create(builder, loc, xConvert, yConvert);
if (isa<IntegerType>(accType))
- return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
- return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
+ return arith::MulIOp::create(builder, loc, xConvert, yConvert);
+ return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}
// Delinearizes the given composite `index` by the basis specified in `factors`.
@@ -56,7 +56,7 @@ static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
assert(!factors.empty() && "empty factor list");
SmallVector<Value> basis;
for (int64_t f : factors)
- basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
+ basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
FailureOr<SmallVector<Value>> multiIndex =
affine::delinearizeIndex(b, loc, index, basis);
assert(!failed(multiIndex) && "Failed to linearize img2col index");
@@ -115,18 +115,18 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
- Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedFilterType, filter, filterReassocIndices);
+ Value reshapedFilter = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
RankedTensorType reshapedOutputType =
RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
- Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedOutputType, output, outputReassocIndices);
+ Value reshapedOutput = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedOutputType, output, outputReassocIndices);
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
- Value colTensor = rewriter.create<tensor::EmptyOp>(
- loc, colTensorShape, inputType.getElementType());
+ Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
+ inputType.getElementType());
// Convert the input to a (BMK) column tensor.
auto nloops = colTensorShape.size();
@@ -138,15 +138,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
SmallVector<AffineMap> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};
- auto img2ColTensor = rewriter.create<linalg::GenericOp>(
- loc, colTensor.getType(),
+ auto img2ColTensor = linalg::GenericOp::create(
+ rewriter, loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+ Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
+ Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
+ Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> mIndices = unrollIndex(
@@ -170,9 +170,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
+ extractionIndices);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
});
// Because the filter does not share the same batch dimension,
@@ -187,8 +187,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
parallel, reduction};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, reshapedOutputType,
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, reshapedOutputType,
/*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
/*outputs=*/ValueRange{reshapedOutput},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
@@ -196,12 +196,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
Value mul =
createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
Value add = createAdd(loc, mul, args[2], nestedBuilder);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
});
Value result = genericOp.getResults().front();
- auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
- loc, outputType, result, outputReassocIndices);
+ auto reshapedResult = tensor::ExpandShapeOp::create(
+ rewriter, loc, outputType, result, outputReassocIndices);
rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
@@ -244,8 +244,8 @@ rewriteInIm2Col(RewriterBase &rewriter,
SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
- Value outputTensor = rewriter.create<tensor::EmptyOp>(
- loc, targetShape, operandTensorType.getElementType());
+ Value outputTensor = tensor::EmptyOp::create(
+ rewriter, loc, targetShape, operandTensorType.getElementType());
SmallVector<utils::IteratorType> loopAttributeTypes(
nloops, utils::IteratorType::parallel);
@@ -255,12 +255,12 @@ rewriteInIm2Col(RewriterBase &rewriter,
AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
- auto transposedOp = rewriter.create<linalg::GenericOp>(
- loc, outputTensor.getType(),
+ auto transposedOp = linalg::GenericOp::create(
+ rewriter, loc, outputTensor.getType(),
/*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
loopAttributeTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
return transposedOp.getResult(0);
@@ -307,15 +307,15 @@ rewriteInIm2Col(RewriterBase &rewriter,
AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
- Value colTensor = rewriter.create<tensor::EmptyOp>(
- loc, colTensorShape, inputType.getElementType());
+ Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
+ inputType.getElementType());
- auto img2ColTensor = rewriter.create<linalg::GenericOp>(
- loc, colTensor.getType(),
+ auto img2ColTensor = linalg::GenericOp::create(
+ rewriter, loc, colTensor.getType(),
/*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
loopAttributeTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
@@ -331,26 +331,27 @@ rewriteInIm2Col(RewriterBase &rewriter,
auto reshapedOutputTensorType =
RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
- Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+ Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
img2ColTensorReassocIndices);
- Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
- Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedOutputTensorType, transposedOutputTensor,
+ Value reshapedFilterTensor =
+ tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
+ filterT, filterReassociationIndice);
+ Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
outputReassociationIndice);
- auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
- loc, TypeRange{reshapedoutputTensor.getType()},
+ auto batchMatVecResult = linalg::BatchMatvecOp::create(
+ rewriter, loc, TypeRange{reshapedoutputTensor.getType()},
ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
ValueRange{reshapedoutputTensor});
SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
{2, 3}};
- auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
- loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
- batchMatVecReassociationIndice);
+ auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
+ rewriter, loc, transposedOutputTensor.getType(),
+ batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
Value transposedResult =
transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
@@ -400,19 +401,19 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
- Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedFilterType, filter, filterReassocIndices);
+ Value reshapedFilter = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
auto reshapedOutputType =
RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
- Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedOutputType, output, outputReassocIndices);
+ Value reshapedOutput = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedOutputType, output, outputReassocIndices);
// Convert the input to a (BKN) tensor.
SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
- Value colTensor = rewriter.create<tensor::EmptyOp>(
- loc, colTensorShape, inputType.getElementType());
+ Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
+ inputType.getElementType());
auto nloops = colTensorShape.size();
@@ -423,15 +424,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
SmallVector<AffineMap, 4> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};
- auto img2ColTensor = rewriter.create<linalg::GenericOp>(
- loc, colTensor.getType(),
+ auto img2ColTensor = linalg::GenericOp::create(
+ rewriter, loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+ Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
+ Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
+ Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> kIndices = unrollIndex(
@@ -455,9 +456,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
// im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
+ extractionIndices);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
});
// Because the filter does not share the same batch dimension,
@@ -471,8 +472,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
parallel, reduction};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, reshapedOutputType,
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, reshapedOutputType,
/*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
/*outputs=*/ValueRange{reshapedOutput},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
@@ -480,12 +481,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
Value mul =
createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
Value add = createAdd(loc, mul, args[2], nestedBuilder);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
});
Value result = genericOp.getResults().front();
- auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
- loc, outputType, result, outputReassocIndices);
+ auto reshapedResult = tensor::ExpandShapeOp::create(
+ rewriter, loc, outputType, result, outputReassocIndices);
rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
@@ -535,18 +536,18 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
auto reshapedFilterType =
RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
- Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedFilterType, filter, filterReassocIndices);
+ Value reshapedFilter = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
RankedTensorType reshapedOutputType =
RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
- Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
- loc, reshapedOutputType, output, outputReassocIndices);
+ Value reshapedOutput = tensor::CollapseShapeOp::create(
+ rewriter, loc, reshapedOutputType, output, outputReassocIndices);
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
- Value colTensor = rewriter.create<tensor::EmptyOp>(
- loc, colTensorShape, inputType.getElementType());
+ Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
+ inputType.getElementType());
// Convert the input to a (BMK) column tensor.
auto nloops = colTensorShape.size();
@@ -558,15 +559,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
SmallVector<AffineMap> img2colIndexingMaps = {
AffineMap::getMultiDimIdentityMap(nloops, context)};
- auto img2ColTensor = rewriter.create<linalg::GenericOp>(
- loc, colTensor.getType(),
+ auto img2ColTensor = linalg::GenericOp::create(
+ rewriter, loc, colTensor.getType(),
/*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
// Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+ Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
+ Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
+ Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
// Recover the original iteration indices from the problem/input sizes.
SmallVector<Value> mIndices = unrollIndex(
@@ -590,9 +591,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
// im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
+ extractionIndices);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
});
// Because we didn't transpose the filters we don't actually have a batched
@@ -606,8 +607,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
parallel, reduction};
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, reshapedOutputType,
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, reshapedOutputType,
/*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
/*outputs=*/ValueRange{reshapedOutput},
ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
@@ -615,12 +616,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Value mul =
createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
Value add = createAdd(loc, mul, args[2], nestedBuilder);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
});
Value result = genericOp.getResults().front();
- auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
- loc, outputType, result, outputReassocIndices);
+ auto reshapedResult = tensor::ExpandShapeOp::create(
+ rewriter, loc, outputType, result, outputReassocIndices);
rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
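The unrollIndex helper at the top of this file recovers multi-dimensional indices from a flattened one via affine::delinearizeIndex. A scalar analogue of row-major delinearization, with a worked example (illustrative only; the real helper emits affine ops, and the row-major convention is an assumption about the basis ordering):

#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major delinearization: the last factor varies fastest.
std::vector<std::int64_t>
delinearize(std::int64_t index, const std::vector<std::int64_t> &factors) {
  std::vector<std::int64_t> out(factors.size());
  for (std::size_t d = factors.size(); d-- > 0;) {
    out[d] = index % factors[d];
    index /= factors[d];
  }
  return out;
}
// E.g. factors {5, 4} and index 13 give {3, 1}; this is how the im2col
// generics above recover (oh, ow) and (fh, fw, ic) from their loop indices.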
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 39e2aac..76ddee4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -37,8 +37,8 @@ static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
if (dim == static_cast<int>(shape.size()) - 1) {
for (int i = 0; i < shape.back(); ++i) {
indices.back() = constants[i];
- destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
- destination, indices);
+ destination = tensor::InsertOp::create(rewriter, loc, *elementIt,
+ destination, indices);
++elementIt;
}
return destination;
@@ -65,27 +65,27 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
MaterializeInDestination: {
// Note: This is the preferred way of memcpy'ing because no layout map
// and/or memory space must be specified for the source.
- auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
- loc, tensorSource, memrefDest);
+ auto materializeOp = bufferization::MaterializeInDestinationOp::create(
+ b, loc, tensorSource, memrefDest);
materializeOp.setWritable(true);
} break;
case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
// TODO: Support custom memory space on source.
// We do not know the layout map of the source yet, so use a fully dynamic
// layout for best compatibility.
- Value toBuffer = b.create<bufferization::ToBufferOp>(
- loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
+ Value toBuffer = bufferization::ToBufferOp::create(
+ b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
tensorSource, /*readOnly=*/true);
- b.create<memref::CopyOp>(loc, toBuffer, memrefDest);
+ memref::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
// TODO: Support custom memory space on source.
// We do not know the layout map of the source yet, so use a fully dynamic
// layout for best compatibility.
- Value toBuffer = b.create<bufferization::ToBufferOp>(
- loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
+ Value toBuffer = bufferization::ToBufferOp::create(
+ b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
tensorSource, /*readOnly=*/true);
- b.create<linalg::CopyOp>(loc, toBuffer, memrefDest);
+ linalg::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
};
}
@@ -120,15 +120,15 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
->materializeConstant(rewriter, constYieldedValue,
yieldedValue.getType(), yieldedValue.getLoc())
->getResult(0);
- auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
- ValueRange(dest));
+ auto fillOp = linalg::FillOp::create(rewriter, loc, ValueRange(fillValue),
+ ValueRange(dest));
return fillOp;
}
if (invariantYieldedValue) {
// Padding with an invariant value.
- auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
- ValueRange(dest));
+ auto fillOp = linalg::FillOp::create(
+ rewriter, loc, ValueRange(yieldedValue), ValueRange(dest));
return fillOp;
}
@@ -137,8 +137,8 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultType, /*inputs=*/ValueRange(),
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, resultType, /*inputs=*/ValueRange(),
/*outputs=*/ValueRange{dest}, /*indexingMaps=*/
indexingMaps, iteratorTypes);
Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
@@ -146,7 +146,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
rewriter.setInsertionPointToStart(body);
SmallVector<Value> bbArgReplacements;
for (int64_t i = 0; i < resultType.getRank(); ++i)
- bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+ bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
// Update terminator.
@@ -179,8 +179,8 @@ static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
for (int64_t i = 0; i < tensorType.getRank(); ++i) {
if (tensorType.isDynamicDim(i))
dynSizes.push_back(
- b.create<DimOp>(value.getLoc(), value,
- b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
+ DimOp::create(b, value.getLoc(), value,
+ arith::ConstantIndexOp::create(b, value.getLoc(), i)));
}
return dynSizes;
}
@@ -201,15 +201,15 @@ createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
Value alloc;
if (options.allocOp ==
linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
- alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
+ alloc = memref::AllocOp::create(rewriter, loc, memrefType, dynamicSizes);
if (options.emitDealloc) {
// Place deallocation at the end of the block.
rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
- rewriter.create<memref::DeallocOp>(loc, alloc);
+ memref::DeallocOp::create(rewriter, loc, alloc);
}
} else if (options.allocOp ==
linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
- alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
+ alloc = memref::AllocaOp::create(rewriter, loc, memrefType, dynamicSizes);
// No dealloc is needed.
}
@@ -243,14 +243,14 @@ Value linalg::bufferizeToAllocation(
getMixedSizes(rewriter, loc, padOp.getSource());
SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
rewriter.getIndexAttr(1));
- Value subview = rewriter.create<memref::SubViewOp>(
- loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
+ Value subview = memref::SubViewOp::create(
+ rewriter, loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
createMemcpy(rewriter, loc, padOp.getSource(), subview, options);
// Create bufferization.to_tensor with "restrict" and "writable". The returned
// tensor is a new buffer allocation, so it does not alias with any buffer.
- Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
- loc, padOp.getResult().getType(), alloc, /*restrict=*/true,
+ Value toTensorOp = bufferization::ToTensorOp::create(
+ rewriter, loc, padOp.getResult().getType(), alloc, /*restrict=*/true,
/*writable=*/true);
rewriter.replaceOp(padOp, toTensorOp);
return alloc;
@@ -338,8 +338,9 @@ Value linalg::bufferizeToAllocation(
// Create bufferization.to_tensor with "restrict" and "writable". The returned
// tensor is a new buffer allocation, so it does not alias with any buffer.
- Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
- loc, allocTensorOp.getResult().getType(), alloc, /*restrict=*/true,
+ Value toTensorOp = bufferization::ToTensorOp::create(
+ rewriter, loc, allocTensorOp.getResult().getType(), alloc,
+ /*restrict=*/true,
/*writable=*/true);
rewriter.replaceOp(allocTensorOp, toTensorOp);
return alloc;
@@ -354,7 +355,7 @@ FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
auto shape = tensorType.getShape();
// Create tensor.empty.
- auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
+ auto emptyOp = EmptyOp::create(rewriter, loc, tensorType, ValueRange());
// Case: tensor<elem_type>.
if (shape.empty()) {
@@ -369,7 +370,7 @@ FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
SmallVector<Value, 2> constants;
constants.reserve(maxDim);
for (int i = 0; i < maxDim; ++i)
- constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+ constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i));
// Traverse all elements and create tensor.insert ops.
auto elementIt = fromElementsOp.getElements().begin();
@@ -394,16 +395,16 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
// Create tensor.empty.
- auto emptyOp =
- rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());
+ auto emptyOp = EmptyOp::create(rewriter, loc, tensorType,
+ generateOp.getDynamicExtents());
// Create linalg.generic.
SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
utils::IteratorType::parallel);
SmallVector<AffineMap> indexingMaps(
1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, tensorType, /*inputs=*/ValueRange(),
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
indexingMaps, iteratorTypes);
Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
@@ -411,7 +412,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
rewriter.setInsertionPointToStart(body);
SmallVector<Value> bbArgReplacements;
for (int64_t i = 0; i < tensorType.getRank(); ++i)
- bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+ bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
// Update terminator.
@@ -450,13 +451,13 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {
using bufferization::AllocTensorOp;
Value allocated =
- rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
+ AllocTensorOp::create(rewriter, loc, resultType, dynamicSizes);
auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
padOp, padOp.getSource(), allocated);
return copyOp.getOperation();
}
- Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
+ Value empty = EmptyOp::create(rewriter, loc, resultType, dynamicSizes);
// Create linalg.fill or linalg.generic.
Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
rewriter.setInsertionPointAfter(fillOp);
@@ -567,8 +568,8 @@ Value linalg::bufferizeToAllocation(
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
}
rewriter.modifyOpInPlace(op, [&]() {
- auto toTensorOp = rewriter.create<ToTensorOp>(
- op->getLoc(), operand->get().getType(), alloc);
+ auto toTensorOp = ToTensorOp::create(rewriter, op->getLoc(),
+ operand->get().getType(), alloc);
operand->set(toTensorOp);
if (options.bufferizeDestinationOnly) {
rewriter.modifyOpInPlace(toTensorOp, [&]() {
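The createMemcpy switch near the top of this file selects among three tensor-to-memref copy lowerings. A hedged usage sketch (the memcpyOp field name is inferred from the case labels; treat the option plumbing as an assumption):

linalg::BufferizeToAllocationOptions options;
// bufferization.materialize_in_destination: preferred, since no layout map
// or memory space must be specified for the source.
options.memcpyOp =
    linalg::BufferizeToAllocationOptions::MemcpyOp::MaterializeInDestination;
// bufferization.to_buffer with a fully dynamic layout, then memref.copy.
options.memcpyOp =
    linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
// The same to_buffer staging, then linalg.copy.
options.memcpyOp =
    linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;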
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7057490..0a9c176 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -287,8 +287,8 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
auto empty = linalg::PackOp::createDestinationTensor(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
- auto packedOperand = b.create<linalg::PackOp>(
- loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
+ auto packedOperand = linalg::PackOp::create(
+ b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
/*padding=*/std::nullopt, outerDimsPerm);
return std::make_tuple(packedOperand, indexingMap);
}
@@ -345,8 +345,9 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
indexingMaps.push_back(packedOutIndexingMap);
- auto newGenericOp = rewriter.create<linalg::GenericOp>(
- loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+ auto newGenericOp = linalg::GenericOp::create(
+ rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
+ iterTypes,
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
newGenericOp.getRegion().begin());
@@ -457,9 +458,9 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
if (!packOpDest.hasOneUse())
return failure();
if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
- packOpDest = rewriter.create<tensor::EmptyOp>(
- genericOp->getLoc(), emptyOp.getMixedSizes(),
- emptyOp.getType().getElementType());
+ packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
+ emptyOp.getMixedSizes(),
+ emptyOp.getType().getElementType());
} else {
DominanceInfo dom(genericOp);
if (!dom.properlyDominates(packOpDest, genericOp))
@@ -562,8 +563,8 @@ public:
auto empty = linalg::PackOp::createDestinationTensor(
rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
outerDimsPerm);
- auto sourcePack = rewriter.create<linalg::PackOp>(
- loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
+ auto sourcePack = linalg::PackOp::create(
+ rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
/*padding=*/std::nullopt, outerDimsPerm);
// If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -579,17 +580,18 @@ public:
lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
- padOp.getNofold());
+ auto newPadOp =
+ tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
+ lowPad, highPad, paddingVal, padOp.getNofold());
// If the pad has more than one user, create an unpack on the new pad to
// replace the other uses.
if (!padOp->hasOneUse()) {
auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
- Value unpackedPad = rewriter.create<linalg::UnPackOp>(
- loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+ Value unpackedPad =
+ linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
+ innerDimsPos, mixedTiles, outerDimsPerm);
rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
}
@@ -719,9 +721,10 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
auto emptyOp = linalg::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
projectedInnerDimsPos, newOuterDimsPerm);
- auto newPackOp = rewriter.create<linalg::PackOp>(
- packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
- packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+ auto newPackOp = linalg::PackOp::create(
+ rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
+ projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
+ newOuterDimsPerm);
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
// First apply the permutation on the reassociations of the outer dims.
@@ -735,8 +738,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
nextPos += 1;
}
- auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
- collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+ auto newCollapseOp = tensor::CollapseShapeOp::create(
+ rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
+ newReassocIndices);
rewriter.replaceOp(packOp, newCollapseOp);
return success();
@@ -853,13 +857,14 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
Value destTensor = linalg::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
- Value packedVal = rewriter.create<linalg::PackOp>(
- packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
- packOp.getMixedTiles(), packOp.getPaddingValue(),
+ Value packedVal = linalg::PackOp::create(
+ rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
+ projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
/*outerDimsPerm=*/SmallVector<int64_t>{});
- Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
- packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
+ Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
+ packOp.getDestType(),
+ packedVal, *reassocExpand);
rewriter.replaceOp(packOp, newExpandOp);
return success();
@@ -972,15 +977,15 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
- auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
- expandOp.getLoc(), newExpandType, unPackOp.getSource(),
- newReassocIndices);
+ auto newExpandOp =
+ tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
+ unPackOp.getSource(), newReassocIndices);
auto emptyOp = linalg::UnPackOp::createDestinationTensor(
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
projectedInnerDimsPos, newOuterDimsPerm);
- auto newUnPackOp = rewriter.create<linalg::UnPackOp>(
- unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+ auto newUnPackOp = linalg::UnPackOp::create(
+ rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
rewriter.replaceOp(expandOp, newUnPackOp);
@@ -1138,10 +1143,9 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
// Insert an unPackOp right after the packed generic.
Value unPackOpRes =
- rewriter
- .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
- destPack.getSource(), innerDimsPos,
- mixedTiles, outerDimsPerm)
+ linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
+ destPack.getSource(), innerDimsPos, mixedTiles,
+ outerDimsPerm)
.getResult();
return std::make_tuple(newGenericOp, unPackOpRes);
@@ -1212,17 +1216,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
- paddingVal, padOp.getNofold());
+ auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
+ unpackOp.getSource(), lowPad, highPad,
+ paddingVal, padOp.getNofold());
// Inject the linalg.unpack right after the packed padOp.
- Value outputUnPack = rewriter.create<tensor::EmptyOp>(
- loc, padOp.getResultType().getShape(),
- padOp.getResultType().getElementType());
+ Value outputUnPack =
+ tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
+ padOp.getResultType().getElementType());
- Value replacement = rewriter.create<linalg::UnPackOp>(
- loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+ Value replacement = linalg::UnPackOp::create(
+ rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
unpackOp.getMixedTiles(), outerDimsPerm);
rewriter.replaceOp(padOp, replacement);
return success();
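The bubbling patterns in this file re-project inner_dims_pos and outer_dims_perm through expand/collapse reassociations. The underlying shape arithmetic of linalg.pack, sketched as plain C++ (illustrative; inferPackedType is the authoritative version, and padding is ignored here):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Every tiled dim shrinks to an outer dim (size / tile); the tile sizes
// are appended as trailing inner dims.
std::vector<std::int64_t>
packedShape(std::vector<std::int64_t> shape,
            const std::vector<std::int64_t> &innerDimsPos,
            const std::vector<std::int64_t> &innerTiles) {
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i) {
    assert(shape[innerDimsPos[i]] % innerTiles[i] == 0 && "would need padding");
    shape[innerDimsPos[i]] /= innerTiles[i];
  }
  shape.insert(shape.end(), innerTiles.begin(), innerTiles.end());
  return shape;
}
// E.g. {128, 256} with inner_dims_pos {0, 1} and inner_tiles {8, 32}
// packs to {16, 8, 8, 32}; outer_dims_perm, when present, permutes the
// leading (outer) dims only.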
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 692bf595..b7da20c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -198,10 +198,10 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
transposedShape[i] = inputRTType.getShape()[permutation[i]];
Value emptyTensor =
- rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
+ tensor::EmptyOp::create(rewriter, loc, transposedShape, elType);
- auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
- emptyTensor, permutation);
+ auto transposeOp = TransposeOp::create(rewriter, loc, newInitValues[i],
+ emptyTensor, permutation);
newInitValues[i] = transposeOp->getResult(0);
isChanged = true;
}
@@ -209,11 +209,11 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
// Does it require broadcast?
if (!broadcastedDims.empty()) {
assert(broadcastedDims.size() && "should have non size broadcast");
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputShape, inputRTType.getElementType());
+ Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape,
+ inputRTType.getElementType());
- auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
- loc, newInitValues[i], emptyTensor, broadcastedDims);
+ auto broadcastOp = linalg::BroadcastOp::create(
+ rewriter, loc, newInitValues[i], emptyTensor, broadcastedDims);
newInitValues[i] = broadcastOp->getResult(0);
isChanged = true;
@@ -227,7 +227,8 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
SmallVector<Value> operands = op->getOperands();
ValueRange operandsRef(operands);
- auto newOp = rewriter.create<linalg::GenericOp>(
+ auto newOp = linalg::GenericOp::create(
+ rewriter,
/*location=*/op.getLoc(),
/*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/newInitValues,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 1419175..c92a27f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -133,13 +133,13 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) {
assert(elementType.isIntOrIndexOrFloat() &&
"expected scalar type while computing zero value");
if (isa<IntegerType>(elementType))
- return b.create<arith::ConstantIntOp>(loc, elementType, 0);
+ return arith::ConstantIntOp::create(b, loc, elementType, 0);
if (elementType.isIndex())
- return b.create<arith::ConstantIndexOp>(loc, 0);
+ return arith::ConstantIndexOp::create(b, loc, 0);
// Assume float.
auto floatType = cast<FloatType>(elementType);
- return b.create<arith::ConstantFloatOp>(
- loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
+ return arith::ConstantFloatOp::create(
+ b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
}
GenericOp
@@ -188,8 +188,8 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
// Fallback path: use an `init_tensor` and identity indexing map.
AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
- Value emptyTensor =
- rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
+ Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, domain,
+ scalarOpResult.getType());
newInitValues.push_back(emptyTensor);
newResultTypes.push_back(emptyTensor.getType());
peeledGenericOpIndexingMaps.push_back(indexingMap);
@@ -202,10 +202,10 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
resultTypes.append(newResultTypes.begin(), newResultTypes.end());
auto indexingMapAttr =
rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
- return rewriter.create<GenericOp>(
- loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
- genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
- [](OpBuilder, Location, ValueRange) {});
+ return GenericOp::create(
+ rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands,
+ indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr,
+ /*libraryCall=*/nullptr, [](OpBuilder, Location, ValueRange) {});
}
GenericOp
@@ -239,8 +239,8 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
- return rewriter.create<GenericOp>(
- genericOp->getLoc(), genericOp->getResultTypes(),
+ return GenericOp::create(
+ rewriter, genericOp->getLoc(), genericOp->getResultTypes(),
residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
[](OpBuilder, Location, ValueRange) {});
@@ -324,7 +324,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
yieldedVals.append(llvm::to_vector(
llvm::map_range(peeledScalarOperation->getResults(),
[](OpResult opr) -> Value { return opr; })));
- rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
+ YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals);
}
/// In the split operations, replace block argument uses that refer to
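
The getZero helper above dispatches on the element type to build the matching zero constant. A usage sketch (hypothetical call sites; getZero is file-static to DecomposeLinalgOps.cpp):

    Value zf = getZero(b, loc, b.getF32Type());      // arith.constant 0.0 : f32
    Value zi = getZero(b, loc, b.getIntegerType(8)); // arith.constant 0 : i8
    Value zx = getZero(b, loc, b.getIndexType());    // arith.constant 0 : index
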
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index ef24eb8..8309054 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -34,8 +34,8 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
// A detensored value is converted back by creating a new tensor from its
// element(s).
- return builder.create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({}, inputType), inputs[0]);
+ return tensor::FromElementsOp::create(
+ builder, loc, RankedTensorType::get({}, inputType), inputs[0]);
}
namespace {
@@ -147,7 +147,7 @@ public:
// A tensor value is detensorized by extracting its element(s).
addTargetMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
- return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
+ return tensor::ExtractOp::create(builder, loc, inputs[0], ValueRange{});
});
addSourceMaterialization(sourceMaterializationCallback);
@@ -480,8 +480,8 @@ struct LinalgDetensorize
Block *postEntryBlock =
rewriter.splitBlock(entryBlock, entryBlock->begin());
rewriter.setInsertionPointToStart(entryBlock);
- auto branch =
- rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
+ auto branch = cf::BranchOp::create(rewriter, rewriter.getUnknownLoc(),
+ postEntryBlock);
if (aggressiveMode.getValue()) {
AggressiveDetensoringModel costModel;
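
In IR terms, the two materializations registered here are inverses of each other: a detensored scalar %s : f32 flowing back into tensor code is rebuilt as %t = tensor.from_elements %s : tensor<f32>, and a 0-d tensor entering detensored code becomes %s = tensor.extract %t[] : tensor<f32>. Both callbacks now use the static create form, matching the rest of this change.
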
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d1..bf66ed0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -118,16 +118,17 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfterValue(op->get());
auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
- auto empty = rewriter.create<tensor::EmptyOp>(
- loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
+ auto empty = tensor::EmptyOp::create(
+ rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
+ elemType);
unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
}
- auto newOp = rewriter.create<GenericOp>(
- loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
- newIndexingMaps, genericOp.getIteratorTypesArray(),
+ auto newOp = GenericOp::create(
+ rewriter, loc, genericOp.getResultTypes(), newInputOperands,
+ newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
OpBuilder::InsertionGuard guard(rewriter);
@@ -266,8 +267,8 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
assert(rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
- return rewriter
- .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
+ return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
+ reassociation)
.getResult();
}
@@ -295,8 +296,8 @@ static Value collapseValue(
MemRefLayoutAttrInterface layout;
auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
layout, memrefType.getMemorySpace());
- return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
- reassociation);
+ return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
+ reassociation);
}
if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
if (rankReductionStrategy ==
@@ -314,8 +315,8 @@ static Value collapseValue(
"unknown rank reduction strategy");
auto targetType =
RankedTensorType::get(targetShape, tensorType.getElementType());
- return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
- reassociation);
+ return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
+ reassociation);
}
llvm_unreachable("unsupported operand type");
}
@@ -331,14 +332,14 @@ struct UnitExtentReplacementInfo {
SmallVector<int64_t> targetShape;
};
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
- MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+ MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
ArrayRef<AffineExpr> dimReplacements) {
UnitExtentReplacementInfo info;
ReassociationIndices reassociationGroup;
SmallVector<AffineExpr> newIndexExprs;
- AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
+ SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
auto isUnitDim = [&](unsigned dim) {
@@ -380,9 +381,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
}
FailureOr<DropUnitDimsResult>
-linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
+ const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
const ControlDropUnitDims &options) {
- SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+ if (!dpsOp) {
+ return rewriter.notifyMatchFailure(
+ op, "op should implement DestinationStyleOpInterface");
+ }
+
+ SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
if (indexingMaps.empty())
return failure();
@@ -392,19 +400,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
AffineMap invertedMap =
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
if (!invertedMap) {
- return rewriter.notifyMatchFailure(genericOp,
+ return rewriter.notifyMatchFailure(op,
"invalid indexing maps for operation");
}
SmallVector<int64_t> allShapesSizes;
- for (OpOperand &opOperand : genericOp->getOpOperands())
- llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
+ for (OpOperand &opOperand : op->getOpOperands())
+ llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
// 1a. Get the allowed list of dimensions to drop from the `options`.
- SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+ SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
if (allowedUnitDims.empty()) {
return rewriter.notifyMatchFailure(
- genericOp, "control function returns no allowed unit dims to prune");
+ op, "control function returns no allowed unit dims to prune");
}
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
allowedUnitDims.end());
@@ -417,19 +425,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
}
}
- // 2. Compute the iterator types of the modified op by dropping the one-trip
+ // 2. Compute the new loops of the modified op by dropping the one-trip
// count loops.
- SmallVector<utils::IteratorType> newIteratorTypes;
llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
SmallVector<AffineExpr> dimReplacements;
unsigned newDims = 0;
- for (auto [index, attr] :
- llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
if (unitDims.count(index)) {
dimReplacements.push_back(
getAffineConstantExpr(0, rewriter.getContext()));
} else {
- newIteratorTypes.push_back(attr);
oldDimToNewDimMap[index] = newDims;
dimReplacements.push_back(
getAffineDimExpr(newDims, rewriter.getContext()));
@@ -462,9 +467,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
}
return false;
};
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
- ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ auto indexingMap = op.getMatchingIndexingMap(&opOperand);
+ SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
if (!hasCollapsibleType(opOperand)) {
AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
@@ -474,9 +479,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
reassociations.push_back({});
continue;
}
- auto replacementInfo = dropUnitExtentFromOperandMetadata(
- rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
- dimReplacements);
+ auto replacementInfo =
+ dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
+ oldDimToNewDimMap, dimReplacements);
reassociations.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
targetShapes.push_back(replacementInfo.targetShape);
@@ -491,13 +496,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
return failure();
- Location loc = genericOp.getLoc();
+ Location loc = op.getLoc();
// 4. For each of the operands, collapse the operand to convert
// from original shape to shape in the modified operation if needed,
// either through use of reshapes or rank-reducing slices as
// specified in `options`.
SmallVector<Value> newOperands;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ for (OpOperand &opOperand : op->getOpOperands()) {
int64_t idx = opOperand.getOperandNumber();
if (!collapsed[idx]) {
newOperands.push_back(opOperand.get());
@@ -508,31 +513,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
options.rankReductionStrategy));
}
- // 5. Create the `linalg.generic` operation with the new operands,
- // indexing maps, iterator types and result types.
- ArrayRef<Value> newInputs =
- ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
- ArrayRef<Value> newOutputs =
- ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
- SmallVector<Type> resultTypes;
- resultTypes.reserve(genericOp.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newOutputs[i].getType());
- GenericOp replacementOp =
- rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
- newIndexingMaps, newIteratorTypes);
- rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
- replacementOp.getRegion().begin());
- // 5a. Replace `linalg.index` operations that refer to the dropped unit
- // dimensions.
- replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+ IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
+ loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
// 6. If any result type changes, insert a reshape/slice to convert from the
// original type to the new type.
SmallVector<Value> resultReplacements;
- for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
- unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
- Value origDest = genericOp.getDpsInitOperand(index)->get();
+ for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
+ unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
+ Value origDest = dpsOp.getDpsInitOperand(index)->get();
if (!collapsed[opOperandIndex]) {
resultReplacements.push_back(result);
continue;
@@ -546,6 +535,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
return DropUnitDimsResult{replacementOp, resultReplacements};
}
+FailureOr<DropUnitDimsResult>
+linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+ const ControlDropUnitDims &options) {
+
+ DroppedUnitDimsBuilder build =
+ [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
+ ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
+ const llvm::SmallDenseSet<unsigned> &droppedDims)
+ -> IndexingMapOpInterface {
+ auto genericOp = cast<GenericOp>(op);
+ // Compute the iterator types of the modified op by dropping the one-trip
+ // count loops.
+ SmallVector<utils::IteratorType> newIteratorTypes;
+ for (auto [index, attr] :
+ llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ if (!droppedDims.count(index))
+ newIteratorTypes.push_back(attr);
+ }
+
+ // Create the `linalg.generic` operation with the new operands,
+ // indexing maps, iterator types and result types.
+ ArrayRef<Value> newInputs =
+ ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+ ArrayRef<Value> newOutputs =
+ ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(genericOp.getNumResults());
+ for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+ resultTypes.push_back(newOutputs[i].getType());
+ GenericOp replacementOp =
+ GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
+ newIndexingMaps, newIteratorTypes);
+ b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
+ replacementOp.getRegion().begin());
+ // 5a. Replace `linalg.index` operations that refer to the dropped unit
+ // dimensions.
+ IRRewriter rewriter(b);
+ replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
+
+ return replacementOp;
+ };
+
+ return dropUnitDims(rewriter, genericOp, build, options);
+}
+
namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
@@ -603,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
}
ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+ ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
int64_t padRank = sourceShape.size();
auto isStaticZero = [](OpFoldResult f) {
@@ -613,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
allowedUnitDims.end());
llvm::SmallDenseSet<unsigned> unitDims;
SmallVector<int64_t> newShape;
+ SmallVector<int64_t> newResultShape;
SmallVector<OpFoldResult> newLowPad;
SmallVector<OpFoldResult> newHighPad;
- for (const auto [dim, size, low, high] :
- zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
- padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+ for (const auto [dim, size, outSize, low, high] : zip_equal(
+ llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+ resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
isStaticZero(high)) {
unitDims.insert(dim);
} else {
newShape.push_back(size);
+ newResultShape.push_back(outSize);
newLowPad.push_back(low);
newHighPad.push_back(high);
}
@@ -652,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
reassociationMap, options.rankReductionStrategy);
+ auto newResultType = RankedTensorType::get(
+ newResultShape, padOp.getResultType().getElementType());
auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+ padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
@@ -670,9 +709,8 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
expandedSizes.push_back(tensor::getMixedSize(
rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
}
- dest = rewriter.create<tensor::EmptyOp>(
- padOp.getLoc(), expandedSizes,
- padOp.getResultType().getElementType());
+ dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
+ padOp.getResultType().getElementType());
}
Value expandedValue =
@@ -713,8 +751,9 @@ struct RankReducedExtractSliceOp
strides));
Location loc = sliceOp.getLoc();
- Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
- loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
+ Value newSlice = tensor::ExtractSliceOp::create(
+ rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
+ strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
sliceOp, resultType, newSlice, *reassociation);
return success();
@@ -747,8 +786,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
// parallel case.
if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
- reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
- loc, insertSliceOp.getSource(), *reassociation);
+ reshapedSource = tensor::CollapseShapeOp::create(
+ rewriter, loc, insertSliceOp.getSource(), *reassociation);
}
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, reshapedSource, insertSliceOp.getDest(),
@@ -898,8 +937,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
/// Expand result tensor.
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType, int64_t dim) const {
- return rewriter.create<tensor::ExpandShapeOp>(
- result.getLoc(), expandedType, result,
+ return tensor::ExpandShapeOp::create(
+ rewriter, result.getLoc(), expandedType, result,
getReassociationForReshapeAtDim(expandedType.getRank(), dim));
}
@@ -934,9 +973,9 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
SmallVector<Type, 1> collapsedResultTy;
if (isa<RankedTensorType>(collapsedInit.getType()))
collapsedResultTy.push_back(collapsedInit.getType());
- auto collapsedOp = rewriter.create<ToOpTy>(
- loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
- ValueRange{collapsedInit});
+ auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
+ ValueRange{collapsedLhs, collapsedRhs},
+ ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
attr.getName() == "indexing_maps")
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8a5c138..3bd763e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -237,12 +237,12 @@ static void generateFusedElementwiseOpRegion(
fusedIndices.reserve(numFusedOpLoops);
llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
std::back_inserter(fusedIndices), [&](uint64_t dim) {
- return rewriter.create<IndexOp>(producer.getLoc(), dim);
+ return IndexOp::create(rewriter, producer.getLoc(), dim);
});
for (IndexOp indexOp :
llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
- Value newIndex = rewriter.create<affine::AffineApplyOp>(
- producer.getLoc(),
+ Value newIndex = affine::AffineApplyOp::create(
+ rewriter, producer.getLoc(),
consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
mapper.map(indexOp.getResult(), newIndex);
}
@@ -328,7 +328,7 @@ static void generateFusedElementwiseOpRegion(
}
for (auto consumerYieldVal : consumerYieldOp.getOperands())
fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
- rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
+ YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
@@ -417,8 +417,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// Generate the fused op.
- auto fusedOp = rewriter.create<GenericOp>(
- consumer.getLoc(), fusedResultTypes, fusedInputOperands,
+ auto fusedOp = GenericOp::create(
+ rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.getIteratorTypes(),
/*doc=*/nullptr,
@@ -751,9 +751,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
- [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
+ [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
OpFoldResult newIndex =
- rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
+ IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
for (auto [expandedShape, expandedIndex] :
llvm::zip(expandedDimsShape, expandedIndices)) {
AffineExpr idx, acc, shape;
@@ -797,8 +797,8 @@ static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
newPerm.push_back(dim);
}
}
- return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
- output, invertPermutationVector(newPerm));
+ return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
+ output, invertPermutationVector(newPerm));
}
// Create an expanded generic op.
@@ -814,9 +814,9 @@ static Operation *createExpandedGenericOp(
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;
- Operation *fused = rewriter.create<GenericOp>(
- linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
- expandedOpIndexingMaps, iteratorTypes);
+ Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
+ expandedOpOperands, outputs,
+ expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fused->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
@@ -934,8 +934,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
reassociation,
/*isExpandingReshape=*/true)))
return std::nullopt;
- expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOperandType, opOperand->get(), reassociation,
+ expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
+ rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
expandedOperandShape));
continue;
}
@@ -962,8 +962,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
reassociation,
/*isExpandingReshape=*/true)))
return std::nullopt;
- outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOutputType, opOperand.get(), reassociation,
+ outputs.push_back(tensor::ExpandShapeOp::create(
+ rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
@@ -985,8 +985,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
linalgOp.getMatchingIndexingMap(
linalgOp.getDpsInitOperand(resultNumber)),
expansionInfo);
- resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
- linalgOp.getLoc(), opResult.getType(),
+ resultVals.push_back(tensor::CollapseShapeOp::create(
+ rewriter, linalgOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(resultNumber));
@@ -1087,8 +1087,8 @@ public:
Location loc = padOp->getLoc();
RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
@@ -1572,12 +1572,12 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
// Insert a reshape to collapse the dimensions.
if (isa<MemRefType>(operand.getType())) {
- return builder
- .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+ return memref::CollapseShapeOp::create(builder, loc, operand,
+ operandReassociation)
.getResult();
}
- return builder
- .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+ return tensor::CollapseShapeOp::create(builder, loc, operand,
+ operandReassociation)
.getResult();
}
@@ -1604,7 +1604,7 @@ static void generateCollapsedIndexingRegion(
enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
ReassociationIndicesRef foldedDimsRef(foldedDims.value());
Value newIndexVal =
- rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
+ linalg::IndexOp::create(rewriter, loc, foldedDims.index());
for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
Value loopDim =
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
@@ -1688,9 +1688,10 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
origOp.getIteratorTypesArray(), collapsingInfo));
- GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
- origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
- iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
+ GenericOp collapsedOp = linalg::GenericOp::create(
+ rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
+ indexingMaps, iteratorTypes,
+ [](OpBuilder &builder, Location loc, ValueRange args) {});
Block *origOpBlock = &origOp->getRegion(0).front();
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
@@ -1795,12 +1796,12 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
if (isa<MemRefType>(collapsedOpResult.getType())) {
MemRefType expandShapeResultType = MemRefType::get(
originalResultType.getShape(), originalResultType.getElementType());
- result = rewriter.create<memref::ExpandShapeOp>(
- loc, expandShapeResultType, collapsedOpResult, reassociation,
- resultShape);
+ result = memref::ExpandShapeOp::create(
+ rewriter, loc, expandShapeResultType, collapsedOpResult,
+ reassociation, resultShape);
} else {
- result = rewriter.create<tensor::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation,
+ result = tensor::ExpandShapeOp::create(
+ rewriter, loc, originalResultType, collapsedOpResult, reassociation,
resultShape);
}
results.push_back(result);
@@ -1983,8 +1984,8 @@ public:
RankedTensorType collapsedPaddedType =
paddedType.clone(collapsedPaddedShape);
- auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2118,17 +2119,18 @@ public:
// Create a constant scalar value from the splat constant.
Value scalarConstant =
- rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
+ arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);
SmallVector<Value> outputOperands = genericOp.getOutputs();
- auto fusedOp = rewriter.create<GenericOp>(
- rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
- /*inputs=*/fusedOperands,
- /*outputs=*/outputOperands,
- rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- genericOp.getIteratorTypes(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
+ auto fusedOp =
+ GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
+ genericOp->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputs=*/outputOperands,
+ rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ genericOp.getIteratorTypes(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr);
// Map the block argument corresponding to the replaced argument with the
// scalar constant.
@@ -2184,8 +2186,8 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
modifiedOutput = true;
SmallVector<OpFoldResult> mixedSizes =
tensor::getMixedSizes(rewriter, loc, operandVal);
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, mixedSizes, operandType.getElementType());
+ Value emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, mixedSizes, operandType.getElementType());
op->setOperand(opOperand.getOperandNumber(), emptyTensor);
}
}
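
A pattern recurring in these fusion hunks: the new linalg.generic is created with a no-op body builder, and the original body is then spliced in (via mergeBlocks or cloneRegionBefore) with its block arguments remapped. A compact sketch following the cloneToCollapsedOp hunk above; argReplacements is the caller-assembled list of replacement values for the old block arguments:

    GenericOp newOp = linalg::GenericOp::create(
        rewriter, origOp.getLoc(), resultTypes, inputs, outputs, indexingMaps,
        iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
    // Splice the original body into the new op, remapping its block
    // arguments to the supplied replacement values.
    rewriter.mergeBlocks(&origOp->getRegion(0).front(),
                         &newOp->getRegion(0).front(), argReplacements);
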
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c4af09c..c523153 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -64,8 +64,8 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
continue;
// Extract static / dynamic shape mix from the first operand.
- res.push_back(b.create<tensor::EmptyOp>(
- loc, tensor::getMixedSizes(b, loc, operands.front()),
+ res.push_back(tensor::EmptyOp::create(
+ b, loc, tensor::getMixedSizes(b, loc, operands.front()),
cast<RankedTensorType>(t).getElementType()));
}
return res;
@@ -104,7 +104,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
builder.create(loc, op->getName().getIdentifier(),
regionArgs.take_front(op->getNumOperands()),
resultTypes, op->getAttrs());
- builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+ linalg::YieldOp::create(builder, loc, scalarOp->getResults());
});
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index d375878..9974ccd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -259,8 +259,8 @@ mlir::linalg::deduplicateOperandsAndRemoveDeadResults(
for (Value v : newOutputOperands)
if (isa<TensorType>(v.getType()))
newResultTypes.push_back(v.getType());
- auto newOp = rewriter.create<GenericOp>(
- loc, newResultTypes, newInputOperands, newOutputOperands,
+ auto newOp = GenericOp::create(
+ rewriter, loc, newResultTypes, newInputOperands, newOutputOperands,
rewriter.getAffineMapArrayAttr(newIndexingMaps),
genericOp.getIteratorTypes(), genericOp.getDocAttr(),
genericOp.getLibraryCallAttr(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
index 44469bc..0ca8904 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -72,14 +72,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// Create a tensor of the same size as the output of the pad op.
RankedTensorType padResultType = padOp.getResultType();
auto resultSizes = resultShape[0];
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultSizes, padResultType.getElementType());
+ auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, resultSizes,
+ padResultType.getElementType());
// Fill the tensor with the pad value.
// TODO: There is an option to fill only the boundaries. For now just
// filling the whole tensor.
- auto fillTensor =
- rewriter.create<linalg::FillOp>(loc, padValue, emptyTensor.getResult());
+ auto fillTensor = linalg::FillOp::create(rewriter, loc, padValue,
+ emptyTensor.getResult());
// Construct a slice of the fill result that is to be replaced with the
// result of the generic op. The low pad values are the offsets, the size of
@@ -93,15 +93,15 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
llvm::enumerate(cast<RankedTensorType>(source.getType()).getShape())) {
if (ShapedType::isDynamic(shape.value())) {
sizes.push_back(
- rewriter.create<tensor::DimOp>(loc, source, shape.index())
+ tensor::DimOp::create(rewriter, loc, source, shape.index())
.getResult());
} else {
sizes.push_back(rewriter.getIndexAttr(shape.value()));
}
}
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
- auto slice = rewriter.create<tensor::ExtractSliceOp>(
- loc, fillTensor.getResult(0), offsets, sizes, strides);
+ auto slice = tensor::ExtractSliceOp::create(
+ rewriter, loc, fillTensor.getResult(0), offsets, sizes, strides);
// Clone the generic op.
auto clonedOp =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 9bc7be2..41252c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -277,7 +277,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
if (consumerType != def.getType())
- def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
+ def = tensor::CastOp::create(b, fusedProducer.getLoc(), consumerType, def);
consumerOpOperand.set(def);
return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 05f2157..3e31393 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -61,8 +61,9 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
// All named ops have a region attached that can be inlined.
assert(linalgOp->getNumRegions() == 1 &&
"expect named op to have one region attached");
- GenericOp genericOp = rewriter.create<GenericOp>(
- linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
+ GenericOp genericOp =
+ GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes, inputs,
+ outputs, indexingMaps, iterators);
rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
genericOp.getRegion().begin());
rewriter.replaceOp(linalgOp, genericOp->getResults());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 94ed464..fd530f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -591,8 +591,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
// Create a packing loop that takes `hoistedPackedTensor` as iteration
// argument.
- auto clonedForOp = rewriter.create<scf::ForOp>(
- loc, bvm.lookupOrDefault(forOp.getLowerBound()),
+ auto clonedForOp = scf::ForOp::create(
+ rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()),
bvm.lookupOrDefault(forOp.getUpperBound()),
bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
@@ -640,11 +640,11 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
TransposeOp maybeTransposeOp;
Value paddedTensor = bvm.lookup(opToHoist.getResult());
if (!transposeVector.empty()) {
- Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
- loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
- strides);
- maybeTransposeOp = rewriter.create<linalg::TransposeOp>(
- loc, paddedTensor, outputTensor, transposeVector);
+ Value outputTensor = tensor::ExtractSliceOp::create(
+ rewriter, loc, transposedTensorType, hoistedPackedTensor, offsets,
+ sizes, strides);
+ maybeTransposeOp = linalg::TransposeOp::create(
+ rewriter, loc, paddedTensor, outputTensor, transposeVector);
paddedTensor = maybeTransposeOp.getResult()[0];
}
@@ -652,15 +652,16 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
if (nPackedLoops > 0) {
// Step 4. Create InsertSliceOp at the innermost loop level, inserting an
// optionally transposed padded slice into the packed tensor.
- Value inserted = rewriter.create<tensor::InsertSliceOp>(
- loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);
+ Value inserted = tensor::InsertSliceOp::create(rewriter, loc, paddedTensor,
+ hoistedPackedTensor, offsets,
+ sizes, strides);
// Step 5. Iteratively pop the stack and propagate the yield.
Value valueToYield = inserted;
for (Value iv : llvm::reverse(clonedLoopIvs)) {
auto forOp = scf::getForInductionVarOwner(iv);
rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
- rewriter.create<scf::YieldOp>(loc, valueToYield);
+ scf::YieldOp::create(rewriter, loc, valueToYield);
valueToYield = forOp.getResult(0);
}
}
@@ -712,8 +713,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
rewriter.setInsertionPoint(outerLoop);
SmallVector<Value> dynamicTensorSizes =
analysis.getHoistedPackedTensorSizes(rewriter, loc);
- auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, hoistedPackedTensorType.getShape(),
+ auto emptyOp = tensor::EmptyOp::create(
+ rewriter, loc, hoistedPackedTensorType.getShape(),
hoistedPackedTensorType.getElementType(), dynamicTensorSizes);
return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
@@ -756,8 +757,7 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
Value source = extractSliceOp.getSource();
LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
while (source && source != expectedSource) {
- auto destOp =
- dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
+ auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
if (!destOp)
break;
LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
@@ -840,8 +840,8 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(forOp);
- extracted = rewriter.create<tensor::ExtractSliceOp>(
- hoistedPackedTensor.getLoc(), hoistedPackedTensor,
+ extracted = tensor::ExtractSliceOp::create(
+ rewriter, hoistedPackedTensor.getLoc(), hoistedPackedTensor,
outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
outerSliceOp.getMixedStrides());
rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
@@ -934,8 +934,8 @@ static Value replaceByPackingResult(RewriterBase &rewriter,
// offsets = [maybe_leading_ivs, 0 .. 0].
// sizes = [1 .. 1, transposedShape] (defined above).
// strides = [1 .. 1] (defined above)
- return rewriter.create<tensor::ExtractSliceOp>(
- loc, transposedTensorType, hoistedPackedTensor, offsets,
+ return tensor::ExtractSliceOp::create(
+ rewriter, loc, transposedTensorType, hoistedPackedTensor, offsets,
packingResult.sizes, packingResult.strides);
}
@@ -982,10 +982,11 @@ FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newResult.getDefiningOp());
// Transpose the packed tensor back to the original storage order.
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
- TransposeOp unTransposeOp = rewriter.create<linalg::TransposeOp>(
- loc, newResult, emptyTensor, transposeVector);
+ Value emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, paddedTensorType.getShape(),
+ paddedTensorType.getElementType());
+ TransposeOp unTransposeOp = linalg::TransposeOp::create(
+ rewriter, loc, newResult, emptyTensor, transposeVector);
newResult = unTransposeOp.getResult()[0];
transposeOps.push_back(unTransposeOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index f2e51c29..58986a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -53,9 +53,9 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
assert(index < inits.size());
inits[index] = newInitOperand;
- scf::ForOp newLoop = rewriter.create<scf::ForOp>(
- loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
- inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ scf::ForOp newLoop = scf::ForOp::create(
+ rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
+ loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
// Generate the new yield with the replaced operand.
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
@@ -165,8 +165,7 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
Value source = transferRead.getBase();
// Skip view-like ops and retrieve the actual source operation.
- while (auto srcOp =
- dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
+ while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
source = srcOp.getViewSource();
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
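
Both this hunk and the HoistPadding.cpp one above replace the dyn_cast_or_null-over-getDefiningOp() idiom with the templated accessor, folding the null check and the cast into one call:

    // Before: null-safe two-step cast of the defining op.
    auto srcOp = dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp());
    // After: returns null when `source` is a block argument or its defining
    // op does not implement the interface.
    auto srcOp = source.getDefiningOp<ViewLikeOpInterface>();
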
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 1f3336d..39cc21d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -60,9 +60,9 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
Location loc = genericOp->getLoc();
SmallVector<Value> outputOperands = genericOp.getOutputs();
- auto newOp = rewriter.create<GenericOp>(
- loc, genericOp->getResultTypes(), newOperands, outputOperands,
- newIndexingMaps, genericOp.getIteratorTypesArray());
+ auto newOp = GenericOp::create(rewriter, loc, genericOp->getResultTypes(),
+ newOperands, outputOperands, newIndexingMaps,
+ genericOp.getIteratorTypesArray());
rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
newOp.getRegion().begin());
@@ -77,11 +77,11 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
SmallVector<Value> indicesValues;
for (auto idx : indices)
indicesValues.emplace_back(
- rewriter.create<arith::ConstantIndexOp>(loc, idx));
+ arith::ConstantIndexOp::create(rewriter, loc, idx));
Value scalarValue = opOperand->get();
if (isa<RankedTensorType>(scalarValue.getType())) {
- scalarValue =
- rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+ scalarValue = tensor::ExtractOp::create(rewriter, loc, scalarValue,
+ indicesValues);
}
body->getArgument(idx).replaceAllUsesWith(scalarValue);
body->eraseArgument(idx);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index a92a0c8..96e6eee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -88,7 +88,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
allIndices.reserve(genericOp.getNumLoops());
llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
std::back_inserter(allIndices), [&](uint64_t dim) {
- return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
+ return IndexOp::create(rewriter, indexOp->getLoc(),
+ dim);
});
rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 488041a..38f1a8b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -49,7 +49,7 @@ static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
SmallVector<Value> operands(vals);
affine::canonicalizeMapAndOperands(&exprMap, &operands);
- res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
+ res.push_back(affine::AffineApplyOp::create(b, loc, exprMap, operands));
}
return res;
}
@@ -70,8 +70,9 @@ static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
Operation *terminator = block.getTerminator();
for (OpOperand &operand : terminator->getOpOperands()) {
Value toStore = map.lookupOrDefault(operand.get());
- b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
- indexing[operand.getOperandNumber()]);
+ StoreOpTy::create(b, loc, toStore,
+ outputBuffers[operand.getOperandNumber()],
+ indexing[operand.getOperandNumber()]);
}
}
@@ -145,7 +146,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
auto indexing = makeCanonicalAffineApplies(
b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
indexedValues.push_back(
- b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
+ LoadOpTy::create(b, loc, inputOperand->get(), indexing));
}
// 1.b. Emit load from output views.
for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
@@ -153,7 +154,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
allIvsPlusDims);
indexedValues.push_back(
- b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
+ LoadOpTy::create(b, loc, outputOperand.get(), indexing));
}
// TODO: When a region inliner exists, use it.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index bb1e974..a2bd9d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -59,8 +59,8 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
auto newKernelTy = RankedTensorType::get(
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
kernelTy.getElementType());
- auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
- loc, newKernelTy, kernel, collapsedKernelDims);
+ auto collapsedKernel = tensor::CollapseShapeOp::create(
+ rewriter, loc, newKernelTy, kernel, collapsedKernelDims);
// Collapse init dims.
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
@@ -70,22 +70,23 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
initTy.getDimSize(2), initTy.getDimSize(3)},
initTy.getElementType());
- auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
- loc, newInitTy, init, collapsedInitDims);
+ auto collapsedInit = tensor::CollapseShapeOp::create(rewriter, loc, newInitTy,
+ init, collapsedInitDims);
SmallVector<NamedAttribute> preservedAttrs;
Operation *newConv =
TypeSwitch<Operation *, Operation *>(operation)
.Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
preservedAttrs = getPrunedAttributeList(op);
- return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
- loc, newInitTy, ValueRange{input, collapsedKernel},
+ return DepthwiseConv2DNhwcHwcOp::create(
+ rewriter, loc, newInitTy, ValueRange{input, collapsedKernel},
ValueRange{collapsedInit}, stride, dilation);
})
.Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
preservedAttrs = getPrunedAttributeList(op);
- return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
- loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
+ return DepthwiseConv2DNhwcHwcQOp::create(
+ rewriter, loc, newInitTy,
+ ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation);
})
.Default([](Operation *op) { return nullptr; });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 2afa2f9..9d7f4e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -81,9 +82,8 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
ArrayRef<ReassociationIndices> reassociation) const {
if (operand.getType() == newOperandType)
return operand;
- return rewriter
- .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
- reassociation)
+ return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
+ reassociation)
.getResult();
}
@@ -143,8 +143,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
Type newOperandType, ArrayAttr reassociation) const {
if (operand.getType() == newOperandType)
return operand;
- return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
- operand, reassociation);
+ return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType,
+ operand, reassociation);
}
/// Returns success() if it is unpacking on the innermost dimension.
@@ -220,6 +220,33 @@ public:
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
+    // Folding is not allowed if it would introduce artificial padding.
+    // Folding is also disabled when any dimension or tile size is dynamic,
+    // because the padding size could not be computed, making it impossible
+    // to establish whether "artificial" padding would be created.
+ RankedTensorType unpackedType = packOp.getSourceType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(packOp);
+ for (auto [pos, tileSize, high] :
+ llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+ padOp.getMixedHighPad())) {
+ if (unpackedType.isDynamicDim(pos))
+ return failure();
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return failure();
+ if (ShapedType::isDynamic(tileSize))
+ return failure();
+ std::optional<int64_t> cstHigh = getConstantIntValue(high);
+ if (!cstHigh)
+ return failure();
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ // Do not fold the op if it requires artificial padding.
+ if (paddingSize + cstHigh.value() >= tileSize)
+ return failure();
+ }
+
rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
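
A worked example of the new check, with illustrative numbers: packing a source dim of 15 with inner tile 8 into an outer size of 2 already implies 2 * 8 - 15 = 1 element of padding. Folding a pad with a high pad of 7 would raise the total to 1 + 7 = 8 >= 8, a full tile of artificial padding, so the fold is rejected; a high pad of at most 6 keeps the total below one tile and the fold proceeds.
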
@@ -251,22 +278,13 @@ public:
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();
- if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "rank-reduced folding is not supported");
- }
-
- // Check all offsets are zeros, and all strides are ones.
- if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
- !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
- return rewriter.notifyMatchFailure(
- sliceOp, "expects offsets to be 0s and strides to be 1s");
- }
+ if (!unpackOp.canFoldSliceOp(sliceOp))
+ return failure();
// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
- Value output = rewriter.create<tensor::EmptyOp>(
- sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
+ Value output = tensor::EmptyOp::create(
+ rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
rewriter.replaceOpWithNewOp<UnPackOp>(
sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
@@ -529,8 +547,8 @@ public:
auto elemType =
cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
- Value output = rewriter.create<tensor::EmptyOp>(
- unPackOp->getLoc(), unpackOpResultDims[0], elemType);
+ Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(),
+ unpackOpResultDims[0], elemType);
rewriter.replaceOpWithNewOp<UnPackOp>(
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 5eb3761..2e62523 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
return paddingSizes;
}
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (binOp.getKind() == AffineExprKind::Mul) {
+ auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+ auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+ if (lhsD && rhsC) {
+ return rhsC.getValue();
+ }
+ auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+ auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+ if (lhsC && rhsD) {
+ return lhsC.getValue();
+ }
+ }
+ }
+ return 1;
+}
+
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
/// - `indexingSizes` a list of OpFoldResult.
/// - an `indexingMap` that encodes how the shape of `v` varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
/*compressDims=*/true);
// If we are padding to the next multiple of a size, compose with
// ceilDiv(sz, multiple) * multiple.
+ OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
- terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, projectedMap, paddingSize);
- terms.push_back(paddingDimOfr);
}
+ // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+ // multiplier.
+ AffineExpr d0;
+ bindDims(rewriter.getContext(), d0);
+ int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+ AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+ OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, subtractMap, {paddingDimOfr});
+ terms.push_back(maxAccessIdx);
+
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
}
@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
- OpFoldResult paddedDimOfr =
- affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+ // Add 1 to the maximum accessed index and get the final padded size.
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
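
As a sanity check on the two adjustments above (subtract the multiplier from each contributing term, then add 1 to the summed maximum index), a standalone arithmetic sketch for the documented map affine_map<(d0, d1) -> (d0 * 3 + d1)> with illustrative padded sizes 4 and 5:

    #include <cassert>

    int main() {
      int p0 = 4, p1 = 5;             // illustrative per-dimension sizes
      int m0 = 3, m1 = 1;             // constant multipliers from d0 * 3 + d1
      int term0 = p0 * m0 - m0;       // (p0 - 1) * 3, max index along d0
      int term1 = p1 * m1 - m1;       // (p1 - 1) * 1, max index along d1
      int padded = term0 + term1 + 1; // sum of terms, then + 1 for the extent
      assert(padded == 14);           // 3 * (4 - 1) + (5 - 1) + 1
      return 0;
    }
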
@@ -192,11 +231,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
- complexTy, complexAttr);
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
} else {
- paddingValue = rewriter.create<arith::ConstantOp>(
- opToPad.getLoc(), cast<TypedAttr>(paddingValueAttr));
+ paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
+ cast<TypedAttr>(paddingValueAttr));
}
// Pad the operand to the bounding box defined by `paddedShape`.
@@ -323,8 +362,8 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
- loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
+ rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index dc9e11e..dd84379 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -219,11 +219,11 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
if (auto complexTy = dyn_cast<ComplexType>(
getElementTypeOrSelf(opOperand->get().getType()))) {
auto complexAttr = cast<ArrayAttr>(paddingAttr);
- paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
- complexTy, complexAttr);
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
} else {
- paddingValue = rewriter.create<arith::ConstantOp>(
- opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
+ paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
+ cast<TypedAttr>(paddingAttr));
}
// Computes the padded shape.
@@ -313,8 +313,8 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
- loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
+ rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
@@ -333,17 +333,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
for (auto it :
llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
- replacements.push_back(rewriter
- .create<linalg::CopyOp>(loc, std::get<0>(it),
- std::get<1>(it).get())
+ replacements.push_back(linalg::CopyOp::create(rewriter, loc,
+ std::get<0>(it),
+ std::get<1>(it).get())
.getResult(0));
} else if (options.copyBackOp ==
LinalgPaddingOptions::CopyBackOp::
BufferizationMaterializeInDestination) {
replacements.push_back(
- rewriter
- .create<bufferization::MaterializeInDestinationOp>(
- loc, std::get<0>(it), std::get<1>(it).get())
+ bufferization::MaterializeInDestinationOp::create(
+ rewriter, loc, std::get<0>(it), std::get<1>(it).get())
->getResult(0));
} else {
llvm_unreachable("unsupported copy back op");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 0433016..f05ffa8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -62,11 +62,11 @@ static Value allocBuffer(ImplicitLocOpBuilder &b,
staticBufferType =
MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr);
if (options.useAlloca) {
- return b.create<memref::AllocaOp>(staticBufferType, ValueRange{},
- alignmentAttr);
+ return memref::AllocaOp::create(b, staticBufferType, ValueRange{},
+ alignmentAttr);
}
- return b.create<memref::AllocOp>(staticBufferType, ValueRange{},
- alignmentAttr);
+ return memref::AllocOp::create(b, staticBufferType, ValueRange{},
+ alignmentAttr);
}
// Fallback dynamic buffer.
@@ -75,10 +75,10 @@ static Value allocBuffer(ImplicitLocOpBuilder &b,
dynamicBufferType =
MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr);
Value mul = b.createOrFold<arith::MulIOp>(
- b.create<arith::ConstantIndexOp>(width), allocSize);
+ arith::ConstantIndexOp::create(b, width), allocSize);
if (options.useAlloca)
- return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr);
- return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr);
+ return memref::AllocaOp::create(b, dynamicBufferType, mul, alignmentAttr);
+ return memref::AllocOp::create(b, dynamicBufferType, mul, alignmentAttr);
}
/// Default allocation callback function. This allocates a promoted buffer when
@@ -91,8 +91,8 @@ static std::optional<Value> defaultAllocBufferCallBack(
std::optional<unsigned> alignment, DataLayout &layout) {
ShapedType viewType = subView.getType();
ImplicitLocOpBuilder b(subView.getLoc(), builder);
- auto zero = b.create<arith::ConstantIndexOp>(0);
- auto one = b.create<arith::ConstantIndexOp>(1);
+ auto zero = arith::ConstantIndexOp::create(b, 0);
+ auto one = arith::ConstantIndexOp::create(b, 1);
Attribute memorySpaceAttr;
if (options.memorySpace.has_value())
@@ -122,8 +122,8 @@ defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
OpBuilder &b, Value fullLocalView) {
if (!options.useAlloca) {
auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
- b.create<memref::DeallocOp>(viewOp.getSource().getLoc(),
- viewOp.getSource());
+ memref::DeallocOp::create(b, viewOp.getSource().getLoc(),
+ viewOp.getSource());
}
return success();
}
@@ -210,7 +210,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
Location loc = linalgOp.getLoc();
auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
Value dst) -> LogicalResult {
- b.create<linalg::CopyOp>(loc, src, dst);
+ linalg::CopyOp::create(b, loc, src, dst);
return success();
};
copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
@@ -264,7 +264,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
- : b.create<arith::ConstantIndexOp>(loc, *upperBound);
+ : arith::ConstantIndexOp::create(b, loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
fullSizes.push_back(size);
@@ -309,23 +309,23 @@ promoteSubViews(ImplicitLocOpBuilder &b,
Value fillVal =
llvm::TypeSwitch<Type, Value>(subviewEltType)
.Case([&](FloatType t) {
- return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0));
+ return arith::ConstantOp::create(b, FloatAttr::get(t, 0.0));
})
.Case([&](IntegerType t) {
- return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0));
+ return arith::ConstantOp::create(b, IntegerAttr::get(t, 0));
})
.Case([&](ComplexType t) {
Value tmp;
if (auto et = dyn_cast<FloatType>(t.getElementType()))
- tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
+ tmp = arith::ConstantOp::create(b, FloatAttr::get(et, 0.0));
else if (auto et = cast<IntegerType>(t.getElementType()))
- tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
- return b.create<complex::CreateOp>(t, tmp, tmp);
+ tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
+ return complex::CreateOp::create(b, t, tmp, tmp);
})
.Default([](auto) { return Value(); });
if (!fillVal)
return failure();
- b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
+ linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
}
// Copy data into the promoted buffers. Use callback if provided.
@@ -458,9 +458,9 @@ static std::optional<Value> allocateSubviewGPUMemoryInAddressSpace(
gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
Value buffer;
if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
- buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+ buffer = memref::AllocOp::create(builder, funcOp.getLoc(), type);
} else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
- buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type);
+ buffer = memref::AllocaOp::create(builder, funcOp.getLoc(), type);
} else {
return std::nullopt;
}
@@ -486,9 +486,9 @@ LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &,
/// the copy operation to ensure data integrity.
LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src,
Value dst) {
- b.create<gpu::BarrierOp>(src.getLoc());
- Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
- b.create<gpu::BarrierOp>(copyOp->getLoc());
+ gpu::BarrierOp::create(b, src.getLoc());
+ Operation *copyOp = memref::CopyOp::create(b, src.getLoc(), src, dst);
+ gpu::BarrierOp::create(b, copyOp->getLoc());
return success();
}
@@ -503,7 +503,7 @@ std::optional<Value> mlir::linalg::allocateGPUPrivateMemory(
/// Normal copy between src and dst.
LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src,
Value dst) {
- b.create<memref::CopyOp>(src.getLoc(), src, dst);
+ memref::CopyOp::create(b, src.getLoc(), src, dst);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index b30182d..eac0e47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -38,8 +38,8 @@ struct StructuredOpInterface
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ auto one = arith::ConstantIndexOp::create(builder, loc, 1);
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 24b8765..f277c5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -36,13 +36,13 @@
namespace mlir::linalg {
-using MeshAxis = mesh::MeshAxis;
-using ReductionKind = mesh::ReductionKind;
-using MeshSharding = mesh::MeshSharding;
-using ShardingArray = mesh::ShardingArray;
-using MeshOp = mesh::MeshOp;
+using GridAxis = shard::GridAxis;
+using ReductionKind = shard::ReductionKind;
+using Sharding = shard::Sharding;
+using ShardingArray = shard::ShardingArray;
+using GridOp = shard::GridOp;
-// Returns the corresponding mesh reduction kind for the given arith op.
+// Returns the corresponding grid reduction kind for the given arith op.
static ReductionKind getReductionKind(Operation *op) {
return llvm::TypeSwitch<Operation *, ReductionKind>(op)
// Floating-point operations.
@@ -97,18 +97,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
return getReductionKind(reductionOp.value());
}
-static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
SymbolTableCollection &symbolTable) {
- for (const MeshSharding &sharding : operandShardings) {
+ for (const Sharding &sharding : operandShardings) {
if (sharding) {
- return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
+ return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
}
}
- for (const MeshSharding &sharding : resultShardings) {
+ for (const Sharding &sharding : resultShardings) {
if (sharding) {
- return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
+ return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
}
}
@@ -117,29 +117,29 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
}
// Choose the operand based on the current process index along the reduction
-// mesh axes.
+// grid axes.
// We need to use the initial value only once to avoid including it in the
// reduction multiple times.
// In each process group only the leading process with linear index 0 would use
// the original operand.
// The other processes would use the reduction operation's neutral tensor.
static Value createDestinationPassingStyleInitOperand(
- LinalgOp op, int operandNumber, Value spmdizedOperand,
- ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+ LinalgOp op, int operandNumber, Value partitionedOperand,
+ ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
ImplicitLocOpBuilder &builder) {
- Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
- meshOp.getSymName(), reductionMeshAxes, builder);
- Value zero = builder.create<arith::ConstantIndexOp>(0);
- Value isLeadProcess = builder.create<arith::CmpIOp>(
- builder.getI1Type(), arith::CmpIPredicate::eq,
+ Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
+ gridOp.getSymName(), reductionGridAxes, builder);
+ Value zero = arith::ConstantIndexOp::create(builder, 0);
+ Value isLeadProcess = arith::CmpIOp::create(
+ builder, builder.getI1Type(), arith::CmpIPredicate::eq,
processLinearIndexInReductionGroup, zero);
- scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
- isLeadProcess, true, true);
+ scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(),
+ isLeadProcess, true, true);
// Then block.
{
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
- builder.create<scf::YieldOp>(spmdizedOperand);
+ scf::YieldOp::create(builder, partitionedOperand);
}
// Else block.
@@ -147,7 +147,7 @@ static Value createDestinationPassingStyleInitOperand(
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
SmallVector<OpFoldResult> shape =
- tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+ tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand);
SmallVector<Operation *> combinerOps;
matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
@@ -155,85 +155,84 @@ static Value createDestinationPassingStyleInitOperand(
std::optional<TypedAttr> neutralEl =
arith::getNeutralElement(combinerOps[0]);
- Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
- neutralEl.value().getType());
+ Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape,
+ neutralEl.value().getType());
Value constant =
- builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
- Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+ arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value());
+ Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init)
.getResult(0);
- builder.create<scf::YieldOp>(fill);
+ scf::YieldOp::create(builder, fill);
}
return ifOp.getResult(0);
}
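
The scf.if built above selects between the real initial operand and a neutral-element tensor so the init participates in the reduction exactly once. A hedged scalar model of that choice (not the actual tensor IR; chooseInit is a hypothetical stand-in):

    #include <cassert>

    // Scalar stand-in for the then/else blocks of the scf.if above.
    double chooseInit(long linearIdx, double realInit, double neutral) {
      return linearIdx == 0 ? realInit : neutral;
    }

    int main() {
      assert(chooseInit(0, 7.0, 0.0) == 7.0); // lead process keeps the init
      assert(chooseInit(3, 7.0, 0.0) == 0.0); // others get the neutral element
      return 0;
    }
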
-// Create the DPS init operands for the spmdized Linalg op.
-// Return all the new spmdized operands.
+// Create the DPS init operands for the partitioned Linalg op.
+// Return all the new partitioned operands.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
- LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+ LinalgOp op, GridOp gridOp, ArrayRef<Value> partitionedOperands,
+ ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
ImplicitLocOpBuilder &builder) {
// TODO: add support for multiple destination passing style initial value
// operands.
assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
- SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+ SmallVector<Value> newOperands = llvm::to_vector(partitionedOperands);
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
- Value spmdizedInitOperand =
- spmdizationMap.lookup(op->getOperands()[operandIdx]);
+ Value partitionedInitOperand =
+ partitionMap.lookup(op->getOperands()[operandIdx]);
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
- op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+ op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
return newOperands;
}
static void createAllReduceForResultsWithoutPartialShardings(
- LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
+ LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
+ ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
ImplicitLocOpBuilder &builder) {
ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
for (auto [unshardedLinalgOpResult, resultSharding] :
llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
- Value spmdizedLinalgOpResult =
- spmdizationMap.lookup(unshardedLinalgOpResult);
- Value reducedValue = builder.create<mesh::AllReduceOp>(
- spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes,
- reductionKind);
- spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+ Value partitionedLinalgOpResult =
+ partitionMap.lookup(unshardedLinalgOpResult);
+ Value reducedValue = shard::AllReduceOp::create(
+ builder, partitionedLinalgOpResult, resultSharding.getGrid(),
+ opReductionGridAxes, reductionKind);
+ partitionMap.map(unshardedLinalgOpResult, reducedValue);
}
}
-static void spmdizeLinalgOpWithShardedReduction(
- LinalgOp op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+static void partitionLinalgOpWithShardedReduction(
+ LinalgOp op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
- IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
+ IRMapping &partitionMap, SymbolTableCollection &symbolTable,
ImplicitLocOpBuilder &builder) {
- MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
- SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
- loopIteratorTypes, meshAxisAssignmentForLoopIterators);
- SmallVector<Value> spmdizedLinalgOpOperands =
- createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
- reductionMeshAxes,
- spmdizationMap, builder);
- // We must not change the operand mappings of the original spmdizationMap as
- // they are the mappings for the whole spmdization blob and may be used by
+ GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
+ SmallVector<GridAxis> reductionGridAxes = shard::getReductionGridAxes(
+ loopIteratorTypes, gridAxisAssignmentForLoopIterators);
+ SmallVector<Value> partitionedLinalgOpOperands =
+ createDestinationPassingStyleInitOperands(op, grid, partitionedOperands,
+ reductionGridAxes, partitionMap,
+ builder);
+ // We must not change the operand mappings of the original partitionMap as
+ // they are the mappings for the whole partition blob and may be used by
// others.
- IRMapping internalSpmdizationMap;
- for (auto [unshardedOperand, spmdizedOperand] :
- llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
- internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
+ IRMapping internalPartitionMap;
+ for (auto [unshardedOperand, partitionedOperand] :
+ llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
+ internalPartitionMap.map(unshardedOperand, partitionedOperand);
}
- spmdizeTriviallyShardableOperation(
- *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
- internalSpmdizationMap, symbolTable, builder);
+ partitionTriviallyShardableOperation(
+ *op, partitionedLinalgOpOperands, operandShardings, resultShardings,
+ internalPartitionMap, symbolTable, builder);
for (Value result : op->getResults()) {
- spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
+ partitionMap.map(result, internalPartitionMap.lookup(result));
}
// Handle partial shardings.
createAllReduceForResultsWithoutPartialShardings(
- op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+ op, reductionGridAxes, resultShardings, partitionMap, builder);
}
namespace {
@@ -243,7 +242,7 @@ namespace {
// permutations.
template <typename Op>
struct StructuredOpShardingInterface
- : public mesh::ShardingInterface::ExternalModel<
+ : public shard::ShardingInterface::ExternalModel<
StructuredOpShardingInterface<Op>, Op> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
@@ -272,16 +271,16 @@ struct StructuredOpShardingInterface
[](unsigned count, utils::IteratorType iter) {
return count + (iter == utils::IteratorType::reduction);
});
- mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+ shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
@@ -297,20 +296,20 @@ struct StructuredOpShardingInterface
SmallVector<utils::IteratorType> loopIteratorTypes =
linalgOp.getIteratorTypesArray();
- ShardingArray meshAxisAssignmentForLoopIterators =
- getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+ ShardingArray gridAxisAssignmentForLoopIterators =
+ getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings,
loopIteratorTypes, indexingMaps);
- if (mesh::isAtLeastOneReductionIteratorSharded(
- loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (shard::isAtLeastOneReductionIteratorSharded(
+ loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
- spmdizeLinalgOpWithShardedReduction(
- linalgOp, spmdizedOperands, operandShardings, resultShardings,
- loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+ partitionLinalgOpWithShardedReduction(
+ linalgOp, partitionedOperands, operandShardings, resultShardings,
+ loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
symbolTable, implicitLocBuilder);
} else {
- spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
- operandShardings, resultShardings,
- spmdizationMap, symbolTable, builder);
+ partitionTriviallyShardableOperation(*op, partitionedOperands,
+ operandShardings, resultShardings,
+ partitionMap, symbolTable, builder);
}
return success();
@@ -330,7 +329,7 @@ static void registerAll(MLIRContext *ctx) {
(registerOne<OpTypes>(ctx), ...);
}
-void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
DialectRegistry registry;
registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 671dea8..76d0ba9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -52,8 +52,8 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
return nullptr;
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
b.getIndexAttr(1));
- Value inserted = b.create<tensor::InsertSliceOp>(
- loc, result, resultOperands[index], resultOffsets, resultSizes,
+ Value inserted = tensor::InsertSliceOp::create(
+ b, loc, result, resultOperands[index], resultOffsets, resultSizes,
resultStrides);
results.push_back(inserted);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 5bfdbc6..b8f8620 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -115,8 +115,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
newShape,
cast<RankedTensorType>(operand->get().getType()).getElementType());
- Value newInput = b.create<tensor::ExpandShapeOp>(
- loc, newType, operand->get(), reassociation);
+ Value newInput = tensor::ExpandShapeOp::create(
+ b, loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
}
@@ -140,18 +140,18 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
}
Value emptyOrAllocTensor;
if (useAlloc) {
- emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
- loc,
+ emptyOrAllocTensor = bufferization::AllocTensorOp::create(
+ b, loc,
RankedTensorType::get(newOutputShape,
op.getRegionOutputArgs()[0].getType()),
ValueRange{});
} else {
- emptyOrAllocTensor = b.create<tensor::EmptyOp>(
- loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+ emptyOrAllocTensor = tensor::EmptyOp::create(
+ b, loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
}
- Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+ Value constantOp = arith::ConstantOp::create(b, loc, *identity);
Value identityTensor =
- b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
+ linalg::FillOp::create(b, op->getLoc(), constantOp, emptyOrAllocTensor)
.getResult(0);
newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
@@ -168,8 +168,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
}
// Create the new op matching the original op with an extra parallel
// dimension.
- GenericOp genericOp = b.create<GenericOp>(
- loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
+ GenericOp genericOp = GenericOp::create(
+ b, loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
ValueRange({identityTensor}), newMaps, newIteratorTypes);
b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
genericOp.getRegion().begin());
@@ -191,14 +191,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
- auto reduction = b.create<GenericOp>(
- loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
+ auto reduction = GenericOp::create(
+ b, loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
op.getDpsInits(), reductionMaps, reductionIteratorTypes,
[reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
Operation *clonedReductionOp = b.clone(*reductionOp);
clonedReductionOp->setOperand(0, inputs[0]);
clonedReductionOp->setOperand(1, inputs[1]);
- b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
});
b.replaceOp(op, reduction.getResults());
@@ -318,14 +318,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
Value emptyOrAllocTensor;
if (useAlloc) {
emptyOrAllocTensor =
- b.create<bufferization::AllocTensorOp>(loc, newT, dims);
+ bufferization::AllocTensorOp::create(b, loc, newT, dims);
} else {
- emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
- t.getElementType(), dims);
+ emptyOrAllocTensor = tensor::EmptyOp::create(b, loc, newT.getShape(),
+ t.getElementType(), dims);
}
- Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
- fillOps.push_back(
- b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
+ Value constantOp = arith::ConstantOp::create(b, loc, std::get<1>(it));
+ fillOps.push_back(linalg::FillOp::create(b, op->getLoc(), constantOp,
+ emptyOrAllocTensor));
newOutputs.push_back(fillOps.back().getResult(0));
emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
}
@@ -354,8 +354,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
SmallVector<Value> newInputs = op.getDpsInputs();
// Add a single shape-only tensor to carry the dimensions without resorting to
// more complex inversions.
- newInputs.push_back(b.create<tensor::EmptyOp>(
- loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
+ newInputs.push_back(tensor::EmptyOp::create(
+ b, loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
b.getIntegerType(1)));
// Output tensors are already good to go.
@@ -365,8 +365,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
utils::IteratorType::parallel);
GenericOp genericOp =
- b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
- newOutputs, newMaps, iteratorTypes);
+ GenericOp::create(b, loc, ValueRange(newOutputs).getTypes(), newInputs,
+ newOutputs, newMaps, iteratorTypes);
b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
genericOp.getRegion().begin());
genericOp.getRegion().front().insertArgument(reductionDimPos,
@@ -396,7 +396,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
utils::IteratorType::reduction;
// clang-format off
- auto reductionOp = b.create<GenericOp>(
+ auto reductionOp = GenericOp::create(b,
loc,
originalOutputType,
reindexedOutput,
@@ -407,7 +407,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
Operation *clonedReductionOp = b.clone(*combinerOp);
clonedReductionOp->setOperand(0, bbArgs[0]);
clonedReductionOp->setOperand(1, bbArgs[1]);
- b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
});
// clang-format on
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
index d35aad5..792ca3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp
@@ -29,10 +29,10 @@ struct SwapExtractSliceOfFill final
if (!fillOp || !fillOp->hasOneUse())
return failure();
- auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
- extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
- extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
- extractOp.getMixedStrides());
+ auto newExtractOp = tensor::ExtractSliceOp::create(
+ rewriter, extractOp.getLoc(), extractOp.getType(),
+ fillOp.getOutputs()[0], extractOp.getMixedOffsets(),
+ extractOp.getMixedSizes(), extractOp.getMixedStrides());
rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(),
ValueRange{newExtractOp.getResult()});
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4741afe..705d6f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -94,11 +94,11 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
return;
}
- Value zero = b.create<arith::ConstantIndexOp>(0);
- Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
- cast<Value>(value), zero);
- b.create<cf::AssertOp>(
- condition,
+ Value zero = arith::ConstantIndexOp::create(b, 0);
+ Value condition = arith::CmpIOp::create(b, arith::CmpIPredicate::sgt,
+ cast<Value>(value), zero);
+ cf::AssertOp::create(
+ b, condition,
b.getStringAttr("expected strictly positive tile size and divisor"));
}
@@ -317,11 +317,12 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
Value coveredSize =
apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
spec.highTileSize, spec.highTripCount});
- Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
- coveredSize, tripCount);
- b.create<cf::AssertOp>(
- equals, builder.getStringAttr(
- "could not compute dynamic multi-size tile shapes"));
+ Value equals = arith::CmpIOp::create(b, arith::CmpIPredicate::eq,
+ coveredSize, tripCount);
+ cf::AssertOp::create(
+ b, equals,
+ builder.getStringAttr(
+ "could not compute dynamic multi-size tile shapes"));
}
return spec;
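
The assertion above checks a covering invariant: the low and high tile sizes, each taken for their trip counts, must add up to the original trip count. A hedged numeric sketch with illustrative values:

    #include <cassert>

    int main() {
      int lowSize = 2, lowTrips = 3;   // illustrative low-tile spec
      int highSize = 4, highTrips = 1; // illustrative high-tile spec
      int tripCount = 10;
      // Mirrors the s0 * s1 + s2 * s3 == tripCount form asserted above.
      assert(lowSize * lowTrips + highSize * highTrips == tripCount);
      return 0;
    }
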
@@ -656,8 +657,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
// 2. Create the ForallOp with an empty region.
- scf::ForallOp forallOp = b.create<scf::ForallOp>(
- loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+ scf::ForallOp forallOp = scf::ForallOp::create(
+ b, loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
mapping);
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
@@ -689,8 +690,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
sizes[reductionDim] = b.getIndexAttr(1);
outOffsets[reductionDim] = forallOp.getInductionVars()[0];
// TODO: use SubsetExtractOpInterface once it is available.
- tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
- loc, cast<RankedTensorType>(initOperand.getType()),
+ tiledDpsInitOperands.push_back(tensor::ExtractSliceOp::create(
+ b, loc, cast<RankedTensorType>(initOperand.getType()),
destBbArgs[destNum], outOffsets, sizes, strides));
}
@@ -768,8 +769,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 6.b. Parallel insertions are inserted at the end of the combining
// terminator.
b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
- b.create<tensor::ParallelInsertSliceOp>(
- loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
+ tensor::ParallelInsertSliceOp::create(
+ b, loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
}
// 7. Merge the partial reductions.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 28d99b1..57b610b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -45,7 +45,7 @@ static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
for (auto result : indexingMap.getResults()) {
AffineMap m = AffineMap::get(indexingMap.getNumDims(),
indexingMap.getNumSymbols(), result);
- Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
+ Value v = affine::AffineApplyOp::create(b, loc, m, ivs);
indices.push_back(v);
}
return indices;
@@ -73,9 +73,9 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
auto indices = getIndicesForAccess(
b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
- b.create<memref::StoreOp>(
- loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
- indices);
+ memref::StoreOp::create(b, loc, toStore,
+ linalgOp.getDpsInitOperand(operand.index())->get(),
+ indices);
}
return success();
}
@@ -352,7 +352,7 @@ struct LinalgOpTilingInterface
SmallVector<Value> indices = getIndicesForAccess(
builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
Value load =
- builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
+ memref::LoadOp::create(builder, linalgOpLoc, operand.get(), indices);
indexedValues.push_back(load);
}
@@ -520,10 +520,10 @@ struct LinalgOpPartialReductionInterface
Type elType = getElementTypeOrSelf(result.getType());
Value emptyTensor =
- b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
- Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+ tensor::EmptyOp::create(b, loc, partialResultShape, elType);
+ Value constantOp = arith::ConstantOp::create(b, loc, *identity);
auto identityTensor =
- b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
+ linalg::FillOp::create(b, loc, constantOp, emptyTensor);
inits.push_back(identityTensor.getResult(0));
}
@@ -575,9 +575,9 @@ struct LinalgOpPartialReductionInterface
RankedTensorType sliceResultType = RankedTensorType::get(
sliceInfo.resultShape, valueToTileType.getElementType(),
valueToTileType.getEncoding());
- auto sliceOp = b.create<tensor::ExtractSliceOp>(
- loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes,
- sliceInfo.strides);
+ auto sliceOp = tensor::ExtractSliceOp::create(
+ b, loc, sliceResultType, valueToTile, sliceInfo.offsets,
+ sliceInfo.sizes, sliceInfo.strides);
tiledInits.push_back(sliceOp.getResult());
generatedSlices.push_back(sliceOp);
}
@@ -604,8 +604,8 @@ struct LinalgOpPartialReductionInterface
auto resultTypes = ValueRange(tiledInits).getTypes();
if (tilingStrategy ==
ReductionTilingStrategy::PartialReductionOuterReduction) {
- auto genericOp = b.create<GenericOp>(
- loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
+ auto genericOp = GenericOp::create(b, loc, resultTypes, tiledInputs,
+ tiledInits, newMaps, newIteratorTypes);
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
@@ -649,8 +649,8 @@ struct LinalgOpPartialReductionInterface
}
}
- auto reduction = b.create<linalg::ReduceOp>(
- loc, partialResult, init, partialReductionDims,
+ auto reduction = linalg::ReduceOp::create(
+ b, loc, partialResult, init, partialReductionDims,
[&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
// Get the combiner op.
SmallVector<Operation *, 4> combinerOps;
@@ -660,7 +660,7 @@ struct LinalgOpPartialReductionInterface
// Combine the input at idx and output at numInits + idx.
clonedReductionOp->setOperand(0, inputs[0]);
clonedReductionOp->setOperand(1, inputs[1]);
- b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
});
mergeOperations.push_back(reduction);
@@ -791,8 +791,8 @@ struct PackOpTiling
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
SmallVector<Value> tiledOperands;
- auto sourceSlice = b.create<tensor::ExtractSliceOp>(
- loc, packOp.getSource(), inputIndices, inputSizes, strides);
+ auto sourceSlice = tensor::ExtractSliceOp::create(
+ b, loc, packOp.getSource(), inputIndices, inputSizes, strides);
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outputOffsets, outputSizes;
@@ -801,8 +801,8 @@ struct PackOpTiling
return {};
strides.append(packOp.getDestRank() - inputRank, oneAttr);
- auto outSlice = b.create<tensor::ExtractSliceOp>(
- loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+ auto outSlice = tensor::ExtractSliceOp::create(
+ b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
tiledOperands.push_back(outSlice);
if (auto val = packOp.getPaddingValue())
@@ -810,8 +810,8 @@ struct PackOpTiling
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);
- Operation *tiledPackOp = b.create<PackOp>(
- loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
+ Operation *tiledPackOp = PackOp::create(
+ b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
return TilingResult{
{tiledPackOp},
@@ -932,20 +932,6 @@ struct PackOpTiling
continue;
}
- // If the dimension needs padding, it is not supported because there are
- // iterations that only write padding values to the whole tile. The
- // consumer fusion is driven by the source, so it is not possible to map
- // an empty slice to the tile.
- bool needExtraPadding =
- ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
- destDimSize * cstInnerSize.value() != srcDimSize;
- // Prioritize the case that the op already says that it does not need
- // padding.
- if (!packOp.getPaddingValue())
- needExtraPadding = false;
- if (needExtraPadding)
- return failure();
-
    // Currently, fusing `packOp` as a consumer expects a perfect tiling
    // scenario, because even without padding semantics the `packOp` may
    // still yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -1007,8 +993,8 @@ struct PackOpTiling
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
SmallVector<Value> tiledOperands;
- auto sourceSlice = b.create<tensor::ExtractSliceOp>(
- loc, packOp.getSource(), offsets, sizes, strides);
+ auto sourceSlice = tensor::ExtractSliceOp::create(
+ b, loc, packOp.getSource(), offsets, sizes, strides);
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
@@ -1023,8 +1009,8 @@ struct PackOpTiling
return failure();
strides.append(packOp.getDestRank() - inputRank, oneAttr);
- auto outSlice = b.create<tensor::ExtractSliceOp>(
- loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+ auto outSlice = tensor::ExtractSliceOp::create(
+ b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
tiledOperands.push_back(outSlice);
if (auto val = packOp.getPaddingValue())
@@ -1032,8 +1018,8 @@ struct PackOpTiling
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);
- Operation *tiledPackOp = b.create<PackOp>(
- loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
+ Operation *tiledPackOp = PackOp::create(
+ b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
return TilingResult{
{tiledPackOp},
@@ -1212,37 +1198,37 @@ struct UnPackOpTiling
sliceSrcSizes.append(unpackOp.getMixedTiles());
sliceSrcStrides.append(numInnerTiles, oneAttr);
SmallVector<Operation *> generatedSlices;
- tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>(
- loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
+ tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create(
+ b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
sliceSrcStrides);
generatedSlices.push_back(sliceSource);
SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
Value sliceDest;
if (isPerfectTilingCase) {
- auto destSliceOp = b.create<tensor::ExtractSliceOp>(
- loc, unpackOp.getDest(), offsets, sizes, destStrides);
+ auto destSliceOp = tensor::ExtractSliceOp::create(
+ b, loc, unpackOp.getDest(), offsets, sizes, destStrides);
sliceDest = destSliceOp;
generatedSlices.push_back(destSliceOp);
} else {
- sliceDest = b.create<tensor::EmptyOp>(
- loc, destExpandedSizes, unpackOp.getDestType().getElementType());
+ sliceDest = tensor::EmptyOp::create(
+ b, loc, destExpandedSizes, unpackOp.getDestType().getElementType());
}
SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
for (auto tile : unpackOp.getInnerTiles())
tiledOperands.push_back(tile);
- Operation *tiledUnpackOp = b.create<UnPackOp>(
- loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
+ Operation *tiledUnpackOp = UnPackOp::create(
+ b, loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
if (isPerfectTilingCase)
return TilingResult{{tiledUnpackOp},
SmallVector<Value>(tiledUnpackOp->getResults()),
generatedSlices};
- auto extractSlice = b.create<tensor::ExtractSliceOp>(
- loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
+ auto extractSlice = tensor::ExtractSliceOp::create(
+ b, loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
destStrides);
return TilingResult{
{tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
@@ -1377,22 +1363,22 @@ struct UnPackOpTiling
SmallVector<Value> tiledOperands;
// Create slice of the dest operand.
- auto extractDestSlice = b.create<tensor::ExtractSliceOp>(
- loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
+ auto extractDestSlice = tensor::ExtractSliceOp::create(
+ b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
tiledOperands.push_back(extractDestSlice);
strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
// Create slice of the source operand.
- auto extractSourceSlice = b.create<tensor::ExtractSliceOp>(
- loc, unPackOp.getSource(), offsets, sizes, strides);
+ auto extractSourceSlice = tensor::ExtractSliceOp::create(
+ b, loc, unPackOp.getSource(), offsets, sizes, strides);
tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
for (auto tile : unPackOp.getInnerTiles())
tiledOperands.push_back(tile);
// Create tiled unpack op.
Operation *tiledUnPackOp =
- b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
- tiledOperands, op->getAttrs());
+ UnPackOp::create(b, loc, TypeRange{extractDestSlice.getType()},
+ tiledOperands, op->getAttrs());
return TilingResult{{tiledUnPackOp},
SmallVector<Value>(tiledUnPackOp->getResults()),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eab74da..bb725f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -269,12 +269,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
packingMetadata.reassociations);
Value paddingValue = packOp.getPaddingValue();
if (!paddingValue) {
- paddingValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+ paddingValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
}
auto padOp =
- rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
- highs, paddingValue, /*nofold=*/false);
+ tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
+ highs, paddingValue, /*nofold=*/false);
LLVM_DEBUG(
DBGSNL(); DBGSNL();
@@ -313,8 +313,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
- auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
+ auto insertSliceOp = tensor::InsertSliceOp::create(
+ rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
/*offsets=*/zeros, sizes, /*strides=*/ones);
LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
@@ -329,15 +329,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// 5. Expand from the padded result to the stripMinedShape.
auto expandShapeResultType =
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc, expandShapeResultType, padOp.getResult(),
+ auto reshapeOp = tensor::ExpandShapeOp::create(
+ rewriter, loc, expandShapeResultType, padOp.getResult(),
packingMetadata.reassociations);
// 6. Transpose stripMinedShape to packedShape.
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+ auto transposeOp = linalg::TransposeOp::create(
+ rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
@@ -371,8 +371,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
- auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, destTensorType, unPackOp.getSource(),
+ auto extractSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, loc, destTensorType, unPackOp.getSource(),
SmallVector<OpFoldResult>(packedRank, zero), sizes,
SmallVector<OpFoldResult>(packedRank, one));
@@ -404,10 +404,11 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
SmallVector<OpFoldResult, 4> dims =
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
- auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedTensorType.getElementType());
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
+ auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
+ stripMinedTensorType.getElementType());
+ auto transposeOp =
+ linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
+ packedToStripMinedShapePerm);
LLVM_DEBUG(
DBGSNL(); DBGSNL();
@@ -426,21 +427,21 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
// 4. Collapse from the stripMinedShape to the padded result.
- auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
- loc, collapsedType, transposeOp->getResult(0),
+ auto reshapeOp = tensor::CollapseShapeOp::create(
+ rewriter, loc, collapsedType, transposeOp->getResult(0),
packingMetadata.reassociations);
// 5. ExtractSlice.
int64_t destRank = destTensorType.getRank();
- auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, destTensorType, reshapeOp->getResult(0),
+ auto extractSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
SmallVector<OpFoldResult>(destRank, one));
// 6. Inject a copy to preserve DPS.
- auto copyOp = rewriter.create<linalg::CopyOp>(
- loc, extractSliceOp->getResult(0), unPackOp.getDest());
+ auto copyOp = linalg::CopyOp::create(
+ rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
// 7. Replace unPackOp by copyOp.
rewriter.replaceOp(unPackOp, copyOp->getResults());
@@ -554,16 +555,16 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
operandType.getShape(), innerPos,
cast<ShapedType>(dest.getType()).getShape(), {},
innerPackSizes)) {
- packOps.push_back(rewriter.create<linalg::PackOp>(
- loc, operand, dest, innerPos, innerPackSizes));
+ packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
+ innerPos, innerPackSizes));
} else {
// TODO: value of the padding attribute should be determined by
// consumers.
auto zeroAttr =
rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
- Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
- packOps.push_back(rewriter.create<linalg::PackOp>(
- loc, operand, dest, innerPos, innerPackSizes, zero));
+ Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
+ packOps.push_back(linalg::PackOp::create(
+ rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
}
inputsAndInits.push_back(packOps.back());
}
@@ -574,9 +575,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
ValueRange inits =
ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
- auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
- linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
- iteratorTypes);
+ auto packedLinalgOp =
+ linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
+ inputs, inits, indexingMaps, iteratorTypes);
packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
// Step 4. Propagate packing to all the op results.
@@ -589,8 +590,8 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
continue;
}
// Build the symmetrical UnPackOp to the existing PackOp.
- unPackOps.push_back(rewriter.create<linalg::UnPackOp>(
- packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+ unPackOps.push_back(linalg::UnPackOp::create(
+ rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
results.push_back(unPackOps.back());
}
@@ -655,7 +656,8 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
operands[opOperand.getOperandNumber()] = transposedValue;
ValueRange operandsRef(operands);
- auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
+ auto transposedGenericOp = linalg::GenericOp::create(
+ rewriter,
/*location=*/linalgOp->getLoc(),
/*resultTensorTypes=*/
operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
@@ -904,7 +906,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
b.setInsertionPointToStart(
&op->getParentOfType<func::FuncOp>().getBody().front());
return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
- Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
+ Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
return v;
}));
};
@@ -926,12 +928,12 @@ Value DecomposePadOpPattern::createFillOrGenerateOp(
// Move the padding value defined inside the PadOp block to outside.
if (padValue.getParentBlock() == &padOp.getRegion().front())
rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
- return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
+ return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
}
// Fill could not be optimized: Lower to tensor::GenerateOp with region.
- auto generateOp = rewriter.create<tensor::GenerateOp>(
- padOp.getLoc(), padOp.getResultType(), dynSizes);
+ auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
+ padOp.getResultType(), dynSizes);
// Copy region to new op.
IRMapping bvm;
padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
@@ -945,9 +947,9 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
auto getIdxValue = [&](OpFoldResult ofr) {
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
return val;
- return rewriter
- .create<arith::ConstantIndexOp>(
- padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
+ return arith::ConstantIndexOp::create(
+ rewriter, padOp.getLoc(),
+ cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
.getResult();
};
@@ -970,8 +972,9 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
}
// Init tensor and fill it with padding.
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
+ Value emptyTensor =
+ tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
+ resultType.getElementType(), dynSizes);
Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
// Generate a InsertSliceOp for copying the PadOp source.
@@ -1222,12 +1225,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
srcPermForTranspose);
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
+ Value empty =
+ tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
+ packOp.getSourceType().getElementType());
// 2.2 Create linalg.transpose
- auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
- srcPermForTranspose);
+ auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
+ srcPermForTranspose);
// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
@@ -1246,9 +1250,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
// 4. Replace tensor.packOp with tensor.insert_slice created above
- auto insert = rewriter.create<tensor::InsertSliceOp>(
- loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
- writeSizes, writeStrides);
+ auto insert = tensor::InsertSliceOp::create(
+ rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
+ writeOffsets, writeSizes, writeStrides);
rewriter.replaceOp(packOp, insert.getResult());
return success();
@@ -1313,7 +1317,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
// outer-untiled-dims
if (ShapedType::isDynamic(srcShape[i])) {
OpFoldResult dynamicDim =
- rewriter.create<tensor::DimOp>(loc, source, i).getResult();
+ tensor::DimOp::create(rewriter, loc, source, i).getResult();
extractSliceSizes.push_back(dynamicDim);
shapeForEmptyOp.push_back(dynamicDim);
} else {
@@ -1340,8 +1344,8 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
Type elemType = unpackOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
- Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, unpackOp.getSource(), extractSliceOffsets,
+ Value innerTile = tensor::ExtractSliceOp::create(
+ rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
extractSliceSizes, extractSliceStrides);
// 2. Transpose the tile to match the outer corresponding tile order.
@@ -1352,9 +1356,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
Value empty =
- rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
+ tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
auto transposedOp =
- rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
+ linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
// 3. Handle in-complete tiles if needed. It truncates trailing data from the
// transposed tile.
@@ -1369,8 +1373,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
}
- auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
- loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
+ auto partialTile =
+ tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
+ tileOffsets, tileSizes, tileStrides);
// 4. Insert the result to the destination tensor.
SmallVector<OpFoldResult> writeSizes;
@@ -1382,9 +1387,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
else
writeSizes.push_back(oneIdxAttr);
}
- auto insert = rewriter.create<tensor::InsertSliceOp>(
- loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
- writeStrides);
+ auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
+ unpackOp.getDest(), writeOffsets,
+ writeSizes, writeStrides);
rewriter.replaceOp(unpackOp, insert.getResult());
return success();
@@ -1491,8 +1496,8 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
- auto conv1DOp = rewriter.create<Conv1DOp>(
- loc, newOutputType, ValueRange{newInput, newKernel},
+ auto conv1DOp = Conv1DOp::create(
+ rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
ValueRange{newOutput}, stridesAttr, dilationsAttr);
// Insert back.
@@ -1578,8 +1583,8 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
- auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
- loc, newOutputType, ValueRange{newInput, newKernel},
+ auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
+ rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
ValueRange{newOutput}, stridesAttr, dilationsAttr);
// Insert back.
@@ -1635,9 +1640,9 @@ DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, output, newOutputType);
- auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
- ValueRange{newInput, newKernel},
- ValueRange{newOutput});
+ auto conv1DOp =
+ Conv1DOp::create(rewriter, loc, newOutputType,
+ ValueRange{newInput, newKernel}, ValueRange{newOutput});
// Insert back.
Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
index 092aecc..35453e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -67,18 +67,17 @@ FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
Value input;
if (isTensorOp) {
- input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
+ input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy)
.getResult();
} else {
- input = rewriter
- .create<memref::AllocOp>(
- loc, MemRefType::get(newFilterShape, elementTy))
+ input = memref::AllocOp::create(rewriter, loc,
+ MemRefType::get(newFilterShape, elementTy))
.getResult();
}
// We can then construct the transposition on our filter.
auto transpose =
- rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
+ linalg::TransposeOp::create(rewriter, loc, filter, input, filterPerm);
Value newFilter;
if (isTensorOp) {
@@ -98,8 +97,8 @@ FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
resultTy.push_back(op->getResult(0).getType());
}
auto newConv =
- rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
- op.getStrides(), op.getDilations());
+ HWCFConvOp::create(rewriter, loc, resultTy, newInputs, op.getOutputs(),
+ op.getStrides(), op.getDilations());
rewriter.replaceOp(op, newConv);
return newConv.getOperation();
}
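Note on the change pattern: every hunk in this patch performs the same mechanical rewrite, replacing the templated builder call rewriter.create<OpTy>(loc, ...) with the static OpTy::create(rewriter, loc, ...) hook, the builder hoisted into the first argument and all other operands unchanged. A minimal sketch, assuming an illustrative ExamplePattern that is not part of this patch:

struct ExamplePattern : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Before this patch:
    //   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    // After this patch:
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    (void)zero;
    return failure();
  }
};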
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d..a2a4335 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -47,25 +47,25 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(1))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
if (type.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
ArrayRef<int64_t> shape = type.getShape();
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
- dynamicDims);
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, input, empty, ArrayRef<int64_t>{1, 0});
+ Value empty = tensor::EmptyOp::create(rewriter, loc,
+ ArrayRef<int64_t>{shape[1], shape[0]},
+ type.getElementType(), dynamicDims);
+ auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
+ ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
- loc, matmulOp.getResultTypes(),
+ newMatmulOp = linalg::MatmulTransposeAOp::create(
+ rewriter, loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
- loc, matmulOp.getResultTypes(),
+ newMatmulOp = linalg::MatmulTransposeBOp::create(
+ rewriter, loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
}
@@ -102,27 +102,27 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
if (type.isDynamicDim(2))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2));
if (type.isDynamicDim(1))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+ dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
ArrayRef<int64_t> shape = type.getShape();
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+ Value empty = tensor::EmptyOp::create(
+ rewriter, loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
type.getElementType(), dynamicDims);
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
+ auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
+ ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
- loc, batchMatmulOp.getResultTypes(),
+ newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
+ rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
- loc, batchMatmulOp.getResultTypes(),
+ newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
+ rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
}
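The Vectorization.cpp diff below pairs the builder migration with a debug-logging cleanup: the file-local DBGS()/LDBG(X) macros are deleted in favor of the streaming LDBG() from llvm/Support/DebugLog.h, which supplies the debug-type prefix and the trailing newline itself, and llvm::interleaved() from llvm/Support/InterleavedRange.h replaces the manual interleaveComma sequences. The shape of the rewrite, taken from the hunks that follow:

// Before:
//   LDBG("Canonical vector shape: ");
//   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
//   LLVM_DEBUG(llvm::dbgs() << "\n");
// After:
LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);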
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 77c85ab..ea68b1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -38,7 +38,8 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -48,9 +49,6 @@ using namespace mlir::linalg;
#define DEBUG_TYPE "linalg-vectorization"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -120,8 +118,9 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
SmallVector<int64_t> strides = {1};
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
+ result.push_back(vector::ExtractStridedSliceOp::create(
+ rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
+ strides));
}
}
} else {
@@ -131,8 +130,8 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
SmallVector<int64_t> strides = {1, 1, 1};
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, input,
+ result.push_back(vector::ExtractStridedSliceOp::create(
+ rewriter, loc, input,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
sizes, strides));
}
@@ -150,8 +149,8 @@ static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
// Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
// non-chanelled convolution] @ [kw].
for (int64_t kw = 0; kw < kwSize; ++kw) {
- result.push_back(rewriter.create<vector::ExtractOp>(
- loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
+ result.push_back(vector::ExtractOp::create(
+ rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
}
return result;
}
@@ -168,8 +167,9 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
SmallVector<int64_t> sizes = {wSizeStep};
SmallVector<int64_t> strides = {1};
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
+ result.push_back(vector::ExtractStridedSliceOp::create(
+ rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
+ strides));
}
} else {
// Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
@@ -177,8 +177,9 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
SmallVector<int64_t> strides = {1, 1, 1};
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
+ result.push_back(vector::ExtractStridedSliceOp::create(
+ rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
+ strides));
}
}
return result;
@@ -195,17 +196,18 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
// This does not depend on kw.
SmallVector<int64_t> strides = {1};
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- res = rewriter.create<vector::InsertStridedSliceOp>(
- loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
+ res = vector::InsertStridedSliceOp::create(
+ rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
+ strides);
}
} else {
// Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
// convolution. This does not depend on kw.
SmallVector<int64_t> strides = {1, 1, 1};
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- res = rewriter.create<vector::InsertStridedSliceOp>(
- loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
- strides);
+ res = vector::InsertStridedSliceOp::create(
+ rewriter, loc, resVals[w], res,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
}
}
return res;
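All four slicing helpers touched above share one loop shape: unroll the output width in wSizeStep chunks and extract (or insert) one strided slice per chunk. Condensed from the channeled case, with every name as in the source:

// Extract one {nSize, wSizeStep, fSize} window per chunk, at offset
// [0, w, 0] with unit strides.
SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
SmallVector<int64_t> strides = {1, 1, 1};
for (int64_t w = 0; w < wSize; w += wSizeStep)
  result.push_back(vector::ExtractStridedSliceOp::create(
      rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
      strides));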
@@ -347,8 +349,8 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
// Create constant index op for static dimensions.
- iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
- linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
+ iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
+ rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
continue;
}
@@ -360,11 +362,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
operandDimPos)))
return failure();
- Value dynamicDim = linalgOp.hasPureTensorSemantics()
- ? (Value)rewriter.create<tensor::DimOp>(
- linalgOp.getLoc(), operand, operandDimPos)
- : (Value)rewriter.create<memref::DimOp>(
- linalgOp.getLoc(), operand, operandDimPos);
+ Value dynamicDim =
+ linalgOp.hasPureTensorSemantics()
+ ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
+ operandDimPos)
+ : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
+ operandDimPos);
iterSpaceValueSizes.push_back(dynamicDim);
}
@@ -398,12 +401,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
scalableVecDims.append(linalgOp.getNumLoops(), false);
}
- LDBG("Canonical vector shape: ");
- LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- LDBG("Scalable vector dims: ");
- LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
+ LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
if (ShapedType::isDynamicShape(canonicalVecShape))
return failure();
@@ -447,14 +446,14 @@ Value VectorizationState::getOrCreateMaskFor(
: AffineMap::getMultiDimIdentityMap(
linalgOp.getNumLoops(), rewriter.getContext());
- LDBG("Masking map: " << maskingMap << "\n");
+ LDBG() << "Masking map: " << maskingMap;
// Return the active mask for the masking map of this operation if it was
// already created.
auto activeMaskIt = activeMaskCache.find(maskingMap);
if (activeMaskIt != activeMaskCache.end()) {
Value mask = activeMaskIt->second;
- LDBG("Reusing mask: " << mask << "\n");
+ LDBG() << "Reusing mask: " << mask;
return mask;
}
@@ -469,12 +468,10 @@ Value VectorizationState::getOrCreateMaskFor(
auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
auto maskShape = maskType.getShape();
- LDBG("Mask shape: ");
- LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
if (permutedStaticSizes == maskShape) {
- LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+ LDBG() << "Masking is not needed for masking map: " << maskingMap;
activeMaskCache[maskingMap] = Value();
return Value();
}
@@ -489,8 +486,9 @@ Value VectorizationState::getOrCreateMaskFor(
? true
: std::get<0>(it) == std::get<1>(it);
})) {
- LDBG("Dynamic + static dimensions match vector sizes, masking is not "
- "required.\n");
+ LDBG()
+ << "Dynamic + static dimensions match vector sizes, masking is not "
+ "required.";
activeMaskCache[maskingMap] = Value();
return Value();
}
@@ -503,9 +501,9 @@ Value VectorizationState::getOrCreateMaskFor(
"Masked 0-d vectors are not supported yet");
// Create the mask based on the dimension values.
- Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
- maskType, upperBounds);
- LDBG("Creating new mask: " << mask << "\n");
+ Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
+ maskType, upperBounds);
+ LDBG() << "Creating new mask: " << mask;
activeMaskCache[maskingMap] = mask;
return mask;
}
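For orientation: getOrCreateMaskFor memoizes one mask per masking map and caches an empty Value when masking is provably unnecessary, so repeated queries short-circuit. The control flow, condensed from the hunks above with non-essential branches elided:

// Cache hit: reuse the mask (a null Value means "no mask needed").
auto it = activeMaskCache.find(maskingMap);
if (it != activeMaskCache.end())
  return it->second;
// Static sizes already match the vector sizes: record "no mask".
if (permutedStaticSizes == maskShape) {
  activeMaskCache[maskingMap] = Value();
  return Value();
}
// Otherwise materialize and cache a vector.create_mask.
Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                          maskType, upperBounds);
activeMaskCache[maskingMap] = mask;
return mask;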
@@ -514,7 +512,7 @@ Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
LinalgOp linalgOp,
std::optional<AffineMap> maybeIndexingMap) {
- LDBG("Trying to mask: " << *opToMask << "\n");
+ LDBG() << "Trying to mask: " << *opToMask;
std::optional<AffineMap> maybeMaskingMap = std::nullopt;
if (maybeIndexingMap)
@@ -525,7 +523,7 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
if (!mask) {
- LDBG("No mask required\n");
+ LDBG() << "No mask required";
return opToMask;
}
@@ -539,7 +537,7 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
maskOpTerminator);
- LDBG("Masked operation: " << *maskOp << "\n");
+ LDBG() << "Masked operation: " << *maskOp;
return maskOp;
}
@@ -672,8 +670,8 @@ static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
ArrayRef<bool> dimsToMask) {
auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
- return b.create<vector::MultiDimReductionOp>(
- reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
+ return vector::MultiDimReductionOp::create(
+ b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
@@ -717,19 +715,20 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
- SmallVector<Value> indices(linalgOp.getRank(outputOperand),
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ SmallVector<Value> indices(
+ linalgOp.getRank(outputOperand),
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType);
assert(value.getType() == vectorType && "Incorrect type");
- write = rewriter.create<vector::TransferWriteOp>(
- loc, value, outputOperand->get(), indices, writeMap);
+ write = vector::TransferWriteOp::create(
+ rewriter, loc, value, outputOperand->get(), indices, writeMap);
} else {
// 0-d case is still special: do not invert the reindexing writeMap.
if (!isa<VectorType>(value.getType()))
- value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
+ value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
assert(value.getType() == vectorType && "Incorrect type");
- write = rewriter.create<vector::TransferWriteOp>(
- loc, value, outputOperand->get(), ValueRange{});
+ write = vector::TransferWriteOp::create(rewriter, loc, value,
+ outputOperand->get(), ValueRange{});
}
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
@@ -742,7 +741,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
- LDBG("vectorized op: " << *write << "\n");
+ LDBG() << "vectorized op: " << *write;
if (!write->getResults().empty())
return write->getResult(0);
return Value();
@@ -807,7 +806,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
auto indexVectorType =
VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
state.getScalableVecDims()[dim]);
- auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
+ auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
// Return the one-dimensional index vector if it lives in the trailing
// dimension of the iteration space since the vectorization algorithm in this
// case can handle the broadcast.
@@ -822,14 +821,14 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
auto permMap =
AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
- auto broadCastOp = rewriter.create<vector::BroadcastOp>(
- loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
- indexSteps);
+ auto broadCastOp = vector::BroadcastOp::create(
+ rewriter, loc,
+ state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
SmallVector<int64_t> transposition =
llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
std::swap(transposition.back(), transposition[dim]);
auto transposeOp =
- rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
+ vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
@@ -882,19 +881,19 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
- Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
auto dimSize = broadcastIfNeeded(
rewriter,
- rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
+ tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
indexVecType);
- offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
+ offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
auto extractOpIndex = broadcastIfNeeded(
rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
- offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
+ offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
}
return offset;
@@ -1084,7 +1083,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
}
if (!leadingIdxsLoopInvariant) {
- LDBG("Found gather load: " << extractOp);
+ LDBG() << "Found gather load: " << extractOp;
return VectorMemoryAccessKind::Gather;
}
@@ -1098,7 +1097,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
- LDBG("Found scalar broadcast load: " << extractOp);
+ LDBG() << "Found scalar broadcast load: " << extractOp;
return VectorMemoryAccessKind::ScalarBroadcast;
}
@@ -1116,12 +1115,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
isContiguousLoad &= (foundIndexOp && isRowVector);
if (isContiguousLoad) {
- LDBG("Found contigous load: " << extractOp);
+ LDBG() << "Found contigous load: " << extractOp;
return VectorMemoryAccessKind::Contiguous;
}
// 4. Fallback case - gather load.
- LDBG("Found gather load: " << extractOp);
+ LDBG() << "Found gather load: " << extractOp;
return VectorMemoryAccessKind::Gather;
}
@@ -1139,18 +1138,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// Compute the static loop sizes of the extract op.
auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
- auto maskConstantOp = rewriter.create<arith::ConstantOp>(
- loc,
+ auto maskConstantOp = arith::ConstantOp::create(
+ rewriter, loc,
DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
/*value=*/true));
- auto passThruConstantOp =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
+ auto passThruConstantOp = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(resultType));
// Base indices are currently set to 0. We will need to re-visit if more
// generic scenarios are to be supported.
SmallVector<Value> baseIndices(
extractOp.getIndices().size(),
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
VectorMemoryAccessKind memAccessKind =
getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
@@ -1160,12 +1159,12 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
// Generate the gather load
- Operation *gatherOp = rewriter.create<vector::GatherOp>(
- loc, resultType, extractOp.getTensor(), baseIndices, offset,
+ Operation *gatherOp = vector::GatherOp::create(
+ rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
maskConstantOp, passThruConstantOp);
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
- LDBG("Vectorised as gather load: " << extractOp << "\n");
+ LDBG() << "Vectorised as gather load: " << extractOp;
return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
}
@@ -1195,13 +1194,13 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
continue;
}
- auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
- loc,
+ auto indexAs1dVector = vector::ShapeCastOp::create(
+ rewriter, loc,
VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
resultType.getScalableDims().back()),
idx);
transferReadIdxs.push_back(
- rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
+ vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
}
// `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1215,8 +1214,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
- auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs,
+ auto transferReadOp = vector::TransferReadOp::create(
+ rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
/*padding=*/std::nullopt, permutationMap, inBounds);
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1224,12 +1223,12 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// valid here).
SmallVector<int64_t> readMaskShape = {1};
auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
- auto allTrue = rewriter.create<vector::ConstantMaskOp>(
- loc, readMaskType, vector::ConstantMaskKind::AllTrue);
+ auto allTrue = vector::ConstantMaskOp::create(
+ rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
auto *maskedReadOp =
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
- LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
+ LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
return VectorizationHookResult{VectorizationHookStatus::NewOp,
maskedReadOp};
}
@@ -1252,11 +1251,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
rankDiff--;
}
- auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs,
+ auto transferReadOp = vector::TransferReadOp::create(
+ rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
/*padding=*/std::nullopt, permutationMap, inBounds);
- LDBG("Vectorised as contiguous load: " << extractOp);
+ LDBG() << "Vectorised as contiguous load: " << extractOp;
return VectorizationHookResult{VectorizationHookStatus::NewOp,
transferReadOp};
}
@@ -1304,7 +1303,7 @@ static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
- LDBG("vectorize op " << *op << "\n");
+ LDBG() << "vectorize op " << *op;
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
@@ -1419,7 +1418,7 @@ static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp,
SmallVectorImpl<Value> &newResults) {
- LDBG("Vectorizing operation as linalg generic\n");
+ LDBG() << "Vectorizing operation as linalg generic/n";
Block *block = linalgOp.getBlock();
// 2. Values defined above the region can only be broadcast for now. Make them
@@ -1434,7 +1433,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3. Turn all BBArgs into vector.transfer_read / load.
Location loc = linalgOp.getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
if (linalgOp.isScalar(opOperand)) {
@@ -1464,8 +1463,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
- Operation *read = rewriter.create<vector::TransferReadOp>(
- loc, readType, opOperand->get(), indices,
+ Operation *read = vector::TransferReadOp::create(
+ rewriter, loc, readType, opOperand->get(), indices,
/*padding=*/std::nullopt, readMap);
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
Value readValue = read->getResult(0);
@@ -1481,11 +1480,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readType.getRank() == 0)
- readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
- ArrayRef<int64_t>());
+ readValue = vector::ExtractOp::create(rewriter, loc, readValue,
+ ArrayRef<int64_t>());
- LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
- << "\n");
+ LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
+ << "): " << readValue;
bvm.map(bbarg, readValue);
bvm.map(opOperand->get(), readValue);
}
@@ -1517,13 +1516,13 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
VectorizationHookResult result =
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationHookStatus::Failure) {
- LDBG("failed to vectorize: " << op << "\n");
+ LDBG() << "failed to vectorize: " << op;
return failure();
}
if (result.status == VectorizationHookStatus::NewOp) {
Operation *maybeMaskedOp =
state.maskOperation(rewriter, result.newOp, linalgOp);
- LDBG("New vector op: " << *maybeMaskedOp << "\n");
+ LDBG() << "New vector op: " << *maybeMaskedOp;
bvm.map(op.getResults(), maybeMaskedOp->getResults());
}
}
@@ -1689,17 +1688,16 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
writeIndices.size() == static_cast<size_t>(destRank)) &&
"Invalid number of write indices!");
if (writeIndices.empty()) {
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
writeIndices.assign(destRank, zero);
}
// Generate the xfer_write Op
- Operation *write =
- builder.create<vector::TransferWriteOp>(loc,
- /*vector=*/vecToStore,
- /*source=*/dest,
- /*indices=*/writeIndices,
- /*inBounds=*/inBoundsVal);
+ Operation *write = vector::TransferWriteOp::create(builder, loc,
+ /*vector=*/vecToStore,
+ /*source=*/dest,
+ /*indices=*/writeIndices,
+ /*inBounds=*/inBoundsVal);
// If masking is disabled, exit.
if (useInBoundsInsteadOfMasking)
@@ -1774,8 +1772,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
Location loc = packOp.getLoc();
auto padValue = packOp.getPaddingValue();
if (!padValue) {
- padValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+ padValue = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
}
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
@@ -1814,17 +1813,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp =
- rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
+ vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
// Create TransposeOp.
auto destPermutation =
invertPermutationVector(getPackInverseDestPerm(packOp));
- auto transposeOp = rewriter.create<vector::TransposeOp>(
- loc, shapeCastOp.getResult(), destPermutation);
+ auto transposeOp = vector::TransposeOp::create(
+ rewriter, loc, shapeCastOp.getResult(), destPermutation);
// Create TransferWriteOp.
- Value dest = rewriter.create<tensor::EmptyOp>(
- loc, reifiedReturnShapes[0],
+ Value dest = tensor::EmptyOp::create(
+ rewriter, loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
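Stitched together, vectorizeAsTensorPackOp is a four-step pipeline; a condensed view with all names from the surrounding hunks:

// 1. masked read of the source, 2. shape_cast to the tiled shape,
// 3. transpose into destination order, 4. write into a fresh tensor.empty.
auto shapeCastOp =
    vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
auto transposeOp = vector::TransposeOp::create(
    rewriter, loc, shapeCastOp.getResult(), destPermutation);
Value dest = tensor::EmptyOp::create(
    rewriter, loc, reifiedReturnShapes[0],
    transposeOp.getResult().getType().getElementType());
Operation *write =
    createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);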
@@ -1914,18 +1913,11 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
- ReifiedRankedShapedTypeDims reifiedRetShapes;
- LogicalResult status =
- cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
- .reifyResultShapes(rewriter, reifiedRetShapes);
- if (status.failed()) {
- LDBG("Unable to reify result shapes of " << unpackOp);
- return failure();
- }
Location loc = unpackOp->getLoc();
- auto padValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
+ auto padValue = arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
// Read result, mask if necessary. If transferReadOp shape is not equal
// to shape of source, then a mask is necessary.
@@ -1943,23 +1935,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
RankedTensorType stripMineTensorType =
RankedTensorType::get(stripMineShape, stripMineElemType);
// Transpose the appropriate rows to match output.
- vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
- loc, readResult, lastDimToInsertPosPerm);
+ vector::TransposeOp transposeOp = vector::TransposeOp::create(
+ rewriter, loc, readResult, lastDimToInsertPosPerm);
// Collapse the vector to the size required by result.
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
mlir::VectorType vecCollapsedType =
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
- vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
- loc, vecCollapsedType, transposeOp->getResult(0));
-
- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
- // otherwise the validator complains that the mask size is invalid.
- SmallVector<int64_t> writeVectorSizes(
- unpackOp.getDestType().hasStaticShape()
- ? vectorSizes
- : shapeCastOp.getResultVectorType().getShape());
+ vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
+ rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
@@ -1992,8 +1978,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
/*useInBoundsInsteadOfMasking=*/false);
// Create Xfer write Op
- Value dest = rewriter.create<tensor::EmptyOp>(
- loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
+ Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
+ padOp.getResultType().getElementType());
Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
newResults.push_back(write->getResult(0));
return success();
@@ -2003,7 +1989,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
- LDBG("reduction precondition failed: no reduction iterator\n");
+ LDBG() << "reduction precondition failed: no reduction iterator";
return failure();
}
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
@@ -2013,7 +1999,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
Operation *reduceOp = matchLinalgReduction(&opOperand);
if (!reduceOp || !getCombinerOpKind(reduceOp)) {
- LDBG("reduction precondition failed: reduction detection failed\n");
+ LDBG() << "reduction precondition failed: reduction detection failed";
return failure();
}
}
@@ -2024,13 +2010,13 @@ static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
bool flatten1DDepthwiseConv) {
if (flatten1DDepthwiseConv) {
- LDBG("Vectorization of flattened convs with dynamic shapes is not "
- "supported\n");
+ LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
+ "supported";
return failure();
}
if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
- LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
+ LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
return failure();
}
@@ -2040,8 +2026,8 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
auto shapeWithoutCh = lhsShape.drop_back(1);
if (ShapedType::isDynamicShape(shapeWithoutCh)) {
- LDBG("Dynamically-shaped op vectorization precondition failed: only "
- "channel dim can be dynamic\n");
+ LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
+ "channel dim can be dynamic";
return failure();
}
@@ -2064,7 +2050,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
op.getOperation()))
return failure();
- LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
+ LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
return success();
}
@@ -2076,7 +2062,7 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
- LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+ LDBG() << "Inner-tiles must be constant: " << unpackOp;
return failure();
}
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
@@ -2116,7 +2102,7 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
!sourceType.hasStaticShape() && inputVectorSizes.empty();
if (!padValue && isOutOfBoundsRead) {
- LDBG("Failed to get a pad value for out-of-bounds read access\n");
+ LDBG() << "Failed to get a pad value for out-of-bounds read access";
return failure();
}
return success();
@@ -2146,7 +2132,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
Operation *reduceOp = matchLinalgReduction(outOperand);
auto maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind) {
- LDBG("Failed to determine contraction combining kind.\n");
+ LDBG() << "Failed to determine contraction combining kind.";
return failure();
}
@@ -2156,7 +2142,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
- LDBG("Contractions with broadcasts are not supported.\n");
+ LDBG() << "Contractions with broadcasts are not supported.";
return failure();
}
@@ -2191,8 +2177,8 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
}
// Create contraction.
- Operation *contractOp = rewriter.create<vector::ContractionOp>(
- loc, /*lhs=*/vecOperands[0],
+ Operation *contractOp = vector::ContractionOp::create(
+ rewriter, loc, /*lhs=*/vecOperands[0],
/*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
@@ -2348,7 +2334,7 @@ static LogicalResult vectorizeLinalgOpPrecondition(
if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
linalgOp, flatten1DDepthwiseConv))) {
- LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
+ LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
return failure();
}
@@ -2390,11 +2376,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
- LDBG("precondition failed: not projected permutations\n");
+ LDBG() << "precondition failed: not projected permutations";
return failure();
}
if (failed(reductionPreconditions(linalgOp))) {
- LDBG("precondition failed: reduction preconditions\n");
+ LDBG() << "precondition failed: reduction preconditions";
return failure();
}
return success();
@@ -2406,7 +2392,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
- LDBG("pad value is not constant: " << packOp << "\n");
+ LDBG() << "pad value is not constant: " << packOp;
return failure();
}
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
@@ -2426,7 +2412,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
return !getConstantIntValue(v).has_value();
})) {
- LDBG("inner_tiles must be constant: " << packOp << "\n");
+ LDBG() << "inner_tiles must be constant: " << packOp;
return failure();
}
@@ -2438,7 +2424,7 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) {
- LDBG("pad value is not constant: " << padOp << "\n");
+ LDBG() << "pad value is not constant: " << padOp;
return failure();
}
@@ -2465,7 +2451,7 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return (!pad.has_value() || pad.value() != 0) &&
resultTensorShape[pos] != 1;
})) {
- LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
+ LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
return failure();
}
@@ -2534,13 +2520,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
case utils::IteratorType::reduction: {
// Check 3. above is met.
if (iterators.size() != inputVectorSizes.size()) {
- LDBG("Non-trailing reduction dim requested for scalable "
- "vectorization\n");
+ LDBG() << "Non-trailing reduction dim requested for scalable "
+ "vectorization";
return failure();
}
if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
- LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
- "is not supported\n");
+ LDBG()
+ << "Scalable vectorization of the reduction dim in Matmul-like ops "
+ "is not supported";
return failure();
}
break;
@@ -2548,8 +2535,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
case utils::IteratorType::parallel: {
// Check 1. and 2. above are met.
if (seenNonUnitParallel) {
- LDBG("Inner parallel dim not requested for scalable "
- "vectorization\n");
+ LDBG() << "Inner parallel dim not requested for scalable "
+ "vectorization";
return failure();
}
break;
@@ -2565,8 +2552,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
// * iterators = [..., parallel, reduction]
// * scalable flags = [..., true, true]
if (iterators.back() == utils::IteratorType::reduction) {
- LDBG("Higher dim than the trailing reduction dim requested for scalable "
- "vectorization\n");
+ LDBG() << "Higher dim than the trailing reduction dim requested for "
+ "scalable "
+ "vectorizatio";
return failure();
}
scalableFlags.pop_back();
@@ -2649,18 +2637,15 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
bool createNamedContraction) {
- LDBG("Attempting to vectorize:\n" << *op << "\n");
- LDBG("Input vector sizes: ");
- LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- LDBG("Input scalable vector dims: ");
- LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "Attempting to vectorize: " << *op;
+ LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
+ LDBG() << "Input scalable vector dims: "
+ << llvm::interleaved(inputScalableVecDims);
if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
vectorizeNDExtract,
flatten1DDepthwiseConv))) {
- LDBG("Vectorization pre-conditions failed\n");
+ LDBG() << "Vectorization pre-conditions failed";
return failure();
}
@@ -2670,7 +2655,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
inputScalableVecDims,
assumeDynamicDimsMatchVecSizes))) {
- LDBG("Vectorization state couldn't be initialized\n");
+ LDBG() << "Vectorization state couldn't be initialized";
return failure();
}
}
@@ -2691,7 +2676,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
return success();
}
- LDBG("Unsupported convolution can't be vectorized.\n");
+ LDBG() << "Unsupported convolution can't be vectorized.";
return failure();
}
@@ -2700,8 +2685,9 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
results);
- LDBG("Vectorize generic by broadcasting to the canonical vector "
- "shape\n");
+ LDBG()
+ << "Vectorize generic by broadcasting to the canonical vector "
+ "shape";
// Pre-process before proceeding.
convertAffineApply(rewriter, linalgOp);
@@ -2732,7 +2718,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
- LDBG("Vectorization failed\n");
+ LDBG() << "Vectorization failed";
return failure();
}
@@ -2756,20 +2742,21 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
auto writeType = VectorType::get(dstType.getShape(), dstElementType);
Location loc = copyOp->getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value> indices(srcType.getRank(), zero);
- Value readValue = rewriter.create<vector::TransferReadOp>(
- loc, readType, copyOp.getSource(), indices,
+ Value readValue = vector::TransferReadOp::create(
+ rewriter, loc, readType, copyOp.getSource(), indices,
/*padding=*/std::nullopt,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
+ readValue = vector::ExtractOp::create(rewriter, loc, readValue,
+ ArrayRef<int64_t>());
readValue =
- rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
- readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
+ vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
}
- Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
- loc, readValue, copyOp.getTarget(), indices,
+ Operation *writeValue = vector::TransferWriteOp::create(
+ rewriter, loc, readValue, copyOp.getTarget(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
rewriter.replaceOp(copyOp, writeValue->getResults());
return success();
@@ -3079,8 +3066,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
if (!padValue) {
auto elemType = sourceType.getElementType();
- padValue = rewriter.create<arith::ConstantOp>(
- sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+ padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
+ rewriter.getZeroAttr(elemType));
}
// 2. Get the vector shape
@@ -3111,7 +3098,7 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create read
SmallVector<Value> readIndices(
- vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
rewriter, loc, source, vecType.getShape(), padValue,
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
@@ -3198,9 +3185,10 @@ struct PadOpVectorizationWithInsertSlicePattern
// Generate TransferReadOp: Read entire source tensor and add high
// padding.
SmallVector<Value> readIndices(
- vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
- auto read = rewriter.create<vector::TransferReadOp>(
- padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
+ vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
+ auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
+ vecType, padOp.getSource(),
+ readIndices, padValue);
// Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
// specified offsets. Write is fully in-bounds because a InsertSliceOp's
@@ -3235,8 +3223,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
!firstOp->isBeforeInBlock(secondOp)) {
- LDBG("interleavedUses precondition failed, firstOp: "
- << *firstOp << ", second op: " << *secondOp << "\n");
+ LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
+ << ", second op: " << *secondOp;
return true;
}
for (auto v : values) {
@@ -3248,8 +3236,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
if (owner->getBlock() == firstOp->getBlock() &&
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
continue;
- LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
- << ", second op: " << *secondOp << "\n");
+ LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
+ << ", second op: " << *secondOp;
return true;
}
}
@@ -3334,8 +3322,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
// When forwarding to vector.transfer_read, the attribute must be reset
// conservatively.
auto vectorType = xferOp.getVectorType();
- Value res = rewriter.create<vector::TransferReadOp>(
- xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
+ Value res = vector::TransferReadOp::create(
+ rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
rewriter.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false)));
@@ -3393,8 +3381,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
// When forwarding to vector.transfer_write, the attribute must be reset
// conservatively.
auto vector = xferOp.getVector();
- rewriter.create<vector::TransferWriteOp>(
- xferOp.getLoc(), vector, out, xferOp.getIndices(),
+ vector::TransferWriteOp::create(
+ rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getMask(),
rewriter.getBoolArrayAttr(SmallVector<bool>(
dyn_cast<VectorType>(vector.getType()).getRank(), false)));
@@ -3589,7 +3577,7 @@ struct Conv1DGenerator
}
vector::TransferWriteOp write;
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
// When strideW == 1, we can batch the contiguous loads and avoid
@@ -3608,17 +3596,17 @@ struct Conv1DGenerator
SmallVector<Value> resPadding(resShape.size(), zero);
// Read the whole lhs, rhs and res in one shot (with zero padding).
- Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, lhsType, lhsShaped, lhsPadding,
+ Value lhs = vector::TransferReadOp::create(
+ rewriter, loc, lhsType, lhsShaped, lhsPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == ConvOperationKind::Conv)
- rhs = rewriter.create<vector::TransferReadOp>(
- loc, rhsType, rhsShaped, rhsPadding,
+ rhs = vector::TransferReadOp::create(
+ rewriter, loc, rhsType, rhsShaped, rhsPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
- Value res = rewriter.create<vector::TransferReadOp>(
- loc, resType, resShaped, resPadding,
+ Value res = vector::TransferReadOp::create(
+ rewriter, loc, resType, resShaped, resPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
// The base vectorization case for channeled convolution is input:
@@ -3633,16 +3621,16 @@ struct Conv1DGenerator
// To match base vectorization case, we pre-transpose current case.
// ncw -> nwc
static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
// fcw -> wcf
static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
// This is needed only for Conv.
if (oper == ConvOperationKind::Conv)
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
// nfw -> nwf
static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
- res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
+ res = vector::TransposeOp::create(rewriter, loc, res, permRes);
break;
}
}
@@ -3707,13 +3695,13 @@ struct Conv1DGenerator
case Conv1DOpOrder::Ncw: {
// nwf -> nfw
static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
- res = rewriter.create<vector::TransposeOp>(loc, res, perm);
+ res = vector::TransposeOp::create(rewriter, loc, res, perm);
break;
}
}
- return rewriter
- .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
+ return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
+ resPadding)
.getOperation();
}
@@ -3731,16 +3719,16 @@ struct Conv1DGenerator
cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
- return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
+ return arith::SIToFPOp::create(rewriter, loc, dstType, val);
}
if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
srcWidth < dstWidth)
- return rewriter.create<arith::ExtFOp>(loc, dstType, val);
+ return arith::ExtFOp::create(rewriter, loc, dstType, val);
if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
srcWidth < dstWidth)
- return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
+ return arith::ExtSIOp::create(rewriter, loc, dstType, val);
assert(false && "unhandled promotion case");
return nullptr;
@@ -3755,8 +3743,8 @@ struct Conv1DGenerator
bindDims(ctx, n, w, f, c);
lhs = promote(rewriter, loc, lhs, res.getType());
rhs = promote(rewriter, loc, rhs, res.getType());
- auto contrationOp = rewriter.create<vector::ContractionOp>(
- loc, lhs, rhs, res,
+ auto contrationOp = vector::ContractionOp::create(
+ rewriter, loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
contrationOp.setKind(reductionKind);
@@ -3767,8 +3755,8 @@ struct Conv1DGenerator
// convolution.
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
- return rewriter.create<vector::OuterProductOp>(
- loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
+ return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
+ rhs, res, vector::CombiningKind::ADD);
}
// Create a reduction: lhs{n, w, c} -> res{n, w, c}
@@ -3815,7 +3803,7 @@ struct Conv1DGenerator
bindShapeDims(resShapedType, nSize, wSize);
vector::TransferWriteOp write;
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
// When strideW == 1, we can batch the contiguous loads and avoid
@@ -3858,29 +3846,29 @@ struct Conv1DGenerator
cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
Value maskOp =
- rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
+ vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
};
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
- Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
+ Value lhs = vector::TransferReadOp::create(
+ rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
auto maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
- Value rhs = rewriter.create<vector::TransferReadOp>(
- loc, rhsType, rhsShaped, ValueRange{zero, zero},
+ Value rhs = vector::TransferReadOp::create(
+ rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
auto maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
- Value res = rewriter.create<vector::TransferReadOp>(
- loc, resType, resShaped, ValueRange{zero, zero, zero},
+ Value res = vector::TransferReadOp::create(
+ rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
auto maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
@@ -3897,22 +3885,22 @@ struct Conv1DGenerator
// @ [0, sw * w + dw * kw, 0].
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, maybeMaskedLhs->getResult(0),
+ lhsVals.push_back(vector::ExtractStridedSliceOp::create(
+ rewriter, loc, maybeMaskedLhs->getResult(0),
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
inOutSliceSizes, inOutStrides));
}
}
// Extract rhs slice of size {c} @ [kw].
for (int64_t kw = 0; kw < kwSize; ++kw) {
- rhsVals.push_back(rewriter.create<vector::ExtractOp>(
- loc, maybeMaskedRhs->getResult(0),
- /*offsets=*/ArrayRef<int64_t>{kw}));
+ rhsVals.push_back(
+ vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
+ /*offsets=*/ArrayRef<int64_t>{kw}));
}
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, maybeMaskedRes->getResult(0),
+ resVals.push_back(vector::ExtractStridedSliceOp::create(
+ rewriter, loc, maybeMaskedRes->getResult(0),
/*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
inOutStrides));
}
@@ -3937,17 +3925,19 @@ struct Conv1DGenerator
if (flatten) {
// Flatten the input and output vectors (collapse the channel
// dimension)
- lhsVal = rewriter.create<vector::ShapeCastOp>(
- loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
- resVal = rewriter.create<vector::ShapeCastOp>(
- loc, resTypeAfterFlattening, resVals[w]);
+ lhsVal =
+ vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
+ lhsVals[linearIndex(kw, w)]);
+ resVal = vector::ShapeCastOp::create(
+ rewriter, loc, resTypeAfterFlattening, resVals[w]);
}
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
rhsVals[kw], resVal, flatten);
if (flatten) {
// Un-flatten the output vector (restore the channel dimension)
- resVals[w] = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+ resVals[w] = vector::ShapeCastOp::create(
+ rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
+ resVals[w]);
}
}
}
@@ -3965,8 +3955,8 @@ struct Conv1DGenerator
// Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
// This does not depend on kw.
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
- loc, resVals[w], maybeMaskedRes->getResult(0),
+ maybeMaskedRes = vector::InsertStridedSliceOp::create(
+ rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
/*strides=*/ArrayRef<int64_t>{1, 1, 1});
}
@@ -3975,8 +3965,8 @@ struct Conv1DGenerator
//===------------------------------------------------------------------===//
// Write back res slice of size {n, w, c} @ [0, 0, 0].
- Operation *resOut = rewriter.create<vector::TransferWriteOp>(
- loc, maybeMaskedRes->getResult(0), resShaped,
+ Operation *resOut = vector::TransferWriteOp::create(
+ rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
ValueRange{zero, zero, zero});
return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
resOut);
@@ -4013,11 +4003,11 @@ struct Conv1DGenerator
indices.push_back(j);
}
- rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
+ rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
}
// Broadcast the filter to match the output vector
- rhs = rewriter.create<vector::BroadcastOp>(
- loc, resTy.clone(rhsTy.getElementType()), rhs);
+ rhs = vector::BroadcastOp::create(rewriter, loc,
+ resTy.clone(rhsTy.getElementType()), rhs);
rhs = promote(rewriter, loc, rhs, resTy);
@@ -4025,10 +4015,10 @@ struct Conv1DGenerator
return nullptr;
if (isa<FloatType>(resTy.getElementType()))
- return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+ return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
- auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
- return rewriter.create<arith::AddIOp>(loc, mul, res);
+ auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
+ return arith::AddIOp::create(rewriter, loc, mul, res);
}
/// Entry point for non-channeled convolution:
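
Every hunk in this file applies the same mechanical rewrite: the member spelling builder.create<OpTy>(loc, args...) becomes the static spelling OpTy::create(builder, loc, args...), with the argument list otherwise unchanged. A minimal self-contained sketch of the two call shapes, using toy types rather than the real MLIR builder API:

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

struct Location {
  std::string file;
  int line;
};

struct OpBuilder {
  template <typename OpTy, typename... Args>
  OpTy create(Location loc, Args &&...args) {
    return OpTy(loc, std::forward<Args>(args)...);
  }
};

struct ConstantIndexOp {
  ConstantIndexOp(Location loc, int64_t value) : value(value) {
    std::cout << "ConstantIndexOp(" << value << ") at " << loc.file << ":"
              << loc.line << "\n";
  }
  // The migrated spelling: a static entry point that forwards to the
  // builder, so both call shapes construct the identical op.
  static ConstantIndexOp create(OpBuilder &b, Location loc, int64_t value) {
    return b.create<ConstantIndexOp>(loc, value);
  }
  int64_t value;
};

int main() {
  OpBuilder b;
  Location loc{"example.mlir", 1};
  auto oldStyle = b.create<ConstantIndexOp>(loc, 0);  // pre-migration shape
  auto newStyle = ConstantIndexOp::create(b, loc, 0); // post-migration shape
  return oldStyle.value == newStyle.value ? 0 : 1;
}

Both spellings build the same op; only the entry point moves, which is why the operand lists in this diff never change.
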
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 9fd0844..b80b27f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -201,11 +201,12 @@ Value create2DTransformMatrix(OpBuilder &builder, Location loc,
TransformMatrix transform, Type type) {
ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
- return builder.create<arith::ConstantOp>(
- loc, DenseFPElementsAttr::get(
- RankedTensorType::get(
- SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ return arith::ConstantOp::create(
+ builder, loc,
+ DenseFPElementsAttr::get(
+ RankedTensorType::get(
+ SmallVector<int64_t>{transform.rows, transform.cols}, type),
+ constVec));
}
/// Extract height x width data from 4D tensors.
@@ -233,8 +234,8 @@ Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
auto extractFilterType =
RankedTensorType::get({extractHeight, extractWidth}, elementType);
- auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
- loc, extractFilterType, source, offsets, sizes, strides);
+ auto extractFilterOp = tensor::ExtractSliceOp::create(
+ builder, loc, extractFilterType, source, offsets, sizes, strides);
return extractFilterOp;
}
@@ -267,8 +268,8 @@ Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
SmallVector<OpFoldResult> strides(srcSize, oneIndex);
auto extractFilterType = RankedTensorType::get({height, width}, elementType);
- auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
- loc, extractFilterType, source, offsets, sizes, strides);
+ auto extractFilterOp = tensor::ExtractSliceOp::create(
+ builder, loc, extractFilterType, source, offsets, sizes, strides);
return extractFilterOp;
}
@@ -293,8 +294,8 @@ Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
retSizes[widthIdx] = builder.getIndexAttr(width);
SmallVector<OpFoldResult> strides(destSize, oneIndex);
- auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
- loc, source, dest, retOffsets, retSizes, strides);
+ auto insertSliceOp = tensor::InsertSliceOp::create(
+ builder, loc, source, dest, retOffsets, retSizes, strides);
return insertSliceOp;
}
@@ -321,8 +322,8 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
retSizes[widthIdx] = builder.getIndexAttr(width);
SmallVector<OpFoldResult> strides(destSize, oneIndex);
- auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
- loc, source, dest, retOffsets, retSizes, strides);
+ auto insertSliceOp = tensor::InsertSliceOp::create(
+ builder, loc, source, dest, retOffsets, retSizes, strides);
return insertSliceOp;
}
@@ -372,7 +373,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
if (filterW != r && filterW != 1)
return Value();
- Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange args) -> scf::ValueVector {
Value FIter = ivs[0];
@@ -386,8 +387,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
int64_t retRows = 1;
Value matmulRetValue = extractFilter;
- Value zero = builder.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(elementType));
+ Value zero = arith::ConstantOp::create(builder, loc,
+ rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix G.
auto it = GMatrices.find(fmr);
@@ -397,16 +398,17 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
retRows = GMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
- auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
- .getResult();
- auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
+ elementType)
+ .getResult();
+ auto init =
+ linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
// Multiply G x g.
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+ auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
+ ValueRange{G, extractFilter},
+ ValueRange{init});
matmulRetValue = matmulOp.getResult(0);
}
@@ -419,16 +421,17 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto matmulType =
RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
- auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
- .getResult();
- auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
+ elementType)
+ .getResult();
+ auto init =
+ linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
// Multiply u = (G x g) x GT.
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+ auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
+ ValueRange{matmulRetValue, GT},
+ ValueRange{init});
matmulRetValue = matmulOp.getResult(0);
}
@@ -445,9 +448,9 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
return {insertSliceOp};
};
- auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
- auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
- auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterF);
+ auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterC);
+ auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
scf::LoopNest loops = scf::buildLoopNest(
rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
{oneStep, oneStep}, {retValue}, buildBody);
@@ -516,10 +519,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
- Value heightOffset = builder.create<affine::AffineApplyOp>(
- loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
- Value widthOffset = builder.create<affine::AffineApplyOp>(
- loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
+ Value heightOffset = affine::AffineApplyOp::create(
+ builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+ Value widthOffset = affine::AffineApplyOp::create(
+ builder, loc, rightTransform ? affineMap : identityAffineMap,
+ tileWIter);
// Extract (H, W) from (N, H, W, C).
auto extractInput =
@@ -530,8 +534,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
int64_t retRows = 1;
int64_t retCols = 1;
Value matmulRetValue = extractInput;
- Value zero = builder.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(elementType));
+ Value zero = arith::ConstantOp::create(builder, loc,
+ rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix BT.
auto it = BTMatrices.find(fmr);
@@ -541,17 +545,18 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
retRows = BTMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
- auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
- .getResult();
- auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
+ elementType)
+ .getResult();
+ auto init =
+ linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value BT =
create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
// Multiply BT x d.
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+ auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
+ ValueRange{BT, matmulRetValue},
+ ValueRange{init});
matmulRetValue = matmulOp.getResult(0);
}
@@ -564,16 +569,17 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
retCols = BMatrix.cols;
auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
- auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
- .getResult();
- auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
+ elementType)
+ .getResult();
+ auto init =
+ linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value B =
create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
// Multiply v = (BT x d) x B.
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+ auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
+ ValueRange{matmulRetValue, B},
+ ValueRange{init});
matmulRetValue = matmulOp.getResult(0);
}
@@ -586,12 +592,12 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
return {combinedVal};
};
- auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
- auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
- auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
- auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
- auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
+ auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
+ auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputN);
+ auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputC);
+ auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
scf::LoopNest loops = scf::buildLoopNest(
rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
{tileHBound, tileWBound, nUpperBound, cUpperBound},
@@ -629,8 +635,8 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
{filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
filterElementType);
SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
- Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
- loc, filterReassocType, transformedFilter, filterReassoc);
+ Value collapseFilter = tensor::CollapseShapeOp::create(
+ rewriter, loc, filterReassocType, transformedFilter, filterReassoc);
// Convert (alphaH, alphaW, tileH, tileW, N, C) to
// (alphaH x alphaW, tileH x tileW x N, C) for input.
@@ -643,24 +649,23 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
inputElementType);
SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
- Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
- loc, inputReassocType, transformedInput, inputReassoc);
+ Value collapseInput = tensor::CollapseShapeOp::create(
+ rewriter, loc, inputReassocType, transformedInput, inputReassoc);
// Batched matrix multiply.
auto matmulType = RankedTensorType::get(
{inputShape[0] * inputShape[1],
inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
outputElementType);
- Value empty = rewriter
- .create<tensor::EmptyOp>(loc, matmulType.getShape(),
- outputElementType)
+ Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
+ outputElementType)
.getResult();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(outputElementType));
- Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(outputElementType));
+ Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0);
- auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
- loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+ auto matmulOp = linalg::BatchMatmulOp::create(
+ rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}),
ValueRange{init});
// The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
@@ -670,8 +675,8 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
inputShape[3], inputShape[4], filterShape[3]},
outputElementType);
- auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
- loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+ auto expandOutput = tensor::ExpandShapeOp::create(
+ rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
return expandOutput;
}
@@ -750,16 +755,17 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
int64_t retRows = leftTransform ? ATMatrix.rows : 1;
Value matmulRetValue = extractValue;
- Value zero = builder.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(elementType));
+ Value zero = arith::ConstantOp::create(builder, loc,
+ rewriter.getZeroAttr(elementType));
auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
- Value heightOffset = builder.create<affine::AffineApplyOp>(
- loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
- Value widthOffset = builder.create<affine::AffineApplyOp>(
- loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
+ Value heightOffset = affine::AffineApplyOp::create(
+ builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+ Value widthOffset = affine::AffineApplyOp::create(
+ builder, loc, rightTransform ? affineMap : identityAffineMap,
+ tileWIter);
Value outInitVal =
extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -771,17 +777,17 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
Value init = outInitVal;
if (rightTransform || scalarFactor != 1) {
- auto empty = builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType)
+ auto empty = tensor::EmptyOp::create(builder, loc,
+ matmulType.getShape(), elementType)
.getResult();
- init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
}
Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
// Multiply AT x m.
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+ auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
+ ValueRange{AT, matmulRetValue},
+ ValueRange{init});
matmulRetValue = matmulOp.getResult(0);
}
@@ -790,47 +796,45 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
RankedTensorType::get({retRows, AMatrix.cols}, elementType);
Value init = outInitVal;
if (scalarFactor != 1) {
- auto empty = builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType)
+ auto empty = tensor::EmptyOp::create(builder, loc,
+ matmulType.getShape(), elementType)
.getResult();
- init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+ init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
}
Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
// Multiply y = (AT x m) x A.
- auto matmulOp = builder.create<linalg::MatmulOp>(
- loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+ auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
+ ValueRange{matmulRetValue, A},
+ ValueRange{init});
matmulRetValue = matmulOp.getResult(0);
}
if (scalarFactor != 1) {
// Multiply by scalar factor and add outInitVal.
- Value scalarFactorValue = builder.create<arith::ConstantOp>(
- loc, FloatAttr::get(elementType, scalarFactor));
+ Value scalarFactorValue = arith::ConstantOp::create(
+ builder, loc, FloatAttr::get(elementType, scalarFactor));
auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
SmallVector<AffineMap> affineMaps = {
AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
matmulRetValue =
- rewriter
- .create<linalg::GenericOp>(
- loc, matmulType,
- ValueRange{scalarFactorValue, matmulRetValue},
- ValueRange{outInitVal}, affineMaps,
- llvm::ArrayRef<utils::IteratorType>{
- utils::IteratorType::parallel,
- utils::IteratorType::parallel},
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- auto mulf = nestedBuilder.create<arith::MulFOp>(
- nestedLoc, args[0], args[1]);
- auto addf = nestedBuilder.create<arith::AddFOp>(
- nestedLoc, mulf.getResult(), args[2]);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc,
- addf.getResult());
- })
+ linalg::GenericOp::create(
+ rewriter, loc, matmulType,
+ ValueRange{scalarFactorValue, matmulRetValue},
+ ValueRange{outInitVal}, affineMaps,
+ llvm::ArrayRef<utils::IteratorType>{
+ utils::IteratorType::parallel, utils::IteratorType::parallel},
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
+ ValueRange args) {
+ auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc,
+ args[0], args[1]);
+ auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc,
+ mulf.getResult(), args[2]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc,
+ addf.getResult());
+ })
.getResult(0);
}
@@ -847,12 +851,12 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
int64_t tilwH = valueShape[2];
int64_t tileW = valueShape[3];
- auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
- auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
- auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
- auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
- auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tilwH);
+ auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
+ auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueN);
+ auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueF);
+ auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
scf::LoopNest loops = scf::buildLoopNest(
rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
{tileHBound, tileWBound, nUpperBound, fUpperBound},
@@ -867,8 +871,8 @@ static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
auto valueType = cast<ShapedType>(value.getType());
Type elementType = valueType.getElementType();
auto alignedType = RankedTensorType::get(alignedShape, elementType);
- Value padValue = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getZeroAttr(elementType));
+ Value padValue = arith::ConstantOp::create(rewriter, loc, elementType,
+ rewriter.getZeroAttr(elementType));
return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
padValue, false);
@@ -887,8 +891,8 @@ static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
SmallVector<OpFoldResult> sizes =
getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
- return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
- offsets, sizes, strides);
+ return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value,
+ offsets, sizes, strides);
}
/// Utility function to check all values in the attribute are 1.
@@ -979,10 +983,10 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
filterElementType);
- Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
- filterElementType);
- auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
- loc, retType, filter, retValue, fmr);
+ Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
+ filterElementType);
+ auto transformedFilter = linalg::WinogradFilterTransformOp::create(
+ rewriter, loc, retType, filter, retValue, fmr);
// --- Create operation for input transform ---
@@ -998,10 +1002,10 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
retType = RankedTensorType::get(
{alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
- retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
- inputElementType);
- auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
- loc, retType, input, retValue, fmr);
+ retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
+ inputElementType);
+ auto transformedInput = linalg::WinogradInputTransformOp::create(
+ rewriter, loc, retType, input, retValue, fmr);
Type outputElementType = outputType.getElementType();
Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
@@ -1023,8 +1027,8 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
outputType = alignedOutputType;
}
- Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
- loc, outputType, matmulRet, output, fmr);
+ Value transformedOutput = linalg::WinogradOutputTransformOp::create(
+ rewriter, loc, outputType, matmulRet, output, fmr);
// When output size is not aligned with output tile size, extract the
// value from the padded buffer.
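
The filter, input, and output transforms stitched together in this file implement the standard 2-D Winograd identity; writing it out makes the sequence of matmuls in the comments above ("Multiply G x g", "u = (G x g) x GT", "BT x d", "v = (BT x d) x B", "AT x m", "y = (AT x m) x A") easier to follow. With g the r x r filter tile and d the (m + r - 1) x (m + r - 1) input tile:

  Y = A^{T}\left[(G\,g\,G^{T}) \odot (B^{T} d\, B)\right] A

Here \odot is elementwise multiplication, realized in matrixMultiply as a batched matmul over the collapsed (alphaH x alphaW) tiles. The concrete G, B, and A tables depend on the chosen (m, r) pair, and the scalarFactor path folds an extra constant into the output transform; read this as the shape of the computation, not a spec of the transform tables.
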
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 14d6200..3593b53 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -320,14 +320,14 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
utils::IteratorType::parallel);
- return b.create<linalg::GenericOp>(
- loc,
+ return linalg::GenericOp::create(
+ b, loc,
/*inputs=*/from,
/*outputs=*/to,
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args.front());
+ linalg::YieldOp::create(b, loc, args.front());
});
}
@@ -483,8 +483,8 @@ static void generateParallelLoopNest(
case DistributionMethod::None: {
// Generate a single parallel loop-nest operation for all outermost
// parallel loops and recurse.
- b.create<scf::ParallelOp>(
- loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
+ scf::ParallelOp::create(
+ b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
steps.take_front(numProcessed),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
ivStorage.append(localIvs.begin(), localIvs.end());
@@ -499,8 +499,8 @@ static void generateParallelLoopNest(
case DistributionMethod::Cyclic: {
// Generate a single parallel loop-nest operation for all outermost
// parallel loops and recurse.
- b.create<scf::ParallelOp>(
- loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
+ scf::ParallelOp::create(
+ b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
steps.take_front(numProcessed),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
ivStorage.append(localIvs.begin(), localIvs.end());
@@ -519,13 +519,13 @@ static void generateParallelLoopNest(
for (unsigned i = 1; i < numProcessed; ++i)
cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
- b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
+ scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) {
generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
ubs.drop_front(numProcessed),
steps.drop_front(numProcessed),
iteratorTypes.drop_front(numProcessed),
remainderProcInfo, bodyBuilderFn, ivStorage);
- b.create<scf::YieldOp>(loc, ValueRange{});
+ scf::YieldOp::create(b, loc, ValueRange{});
});
return;
}
@@ -595,13 +595,13 @@ static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
- return builder.create<memref::SubViewOp>(
- loc, valueToTile, sliceParams.offsets,
+ return memref::SubViewOp::create(
+ builder, loc, valueToTile, sliceParams.offsets,
sliceParams.sizes, sliceParams.strides);
})
.Case([&](RankedTensorType) {
- return builder.create<tensor::ExtractSliceOp>(
- loc, valueToTile, sliceParams.offsets,
+ return tensor::ExtractSliceOp::create(
+ builder, loc, valueToTile, sliceParams.offsets,
sliceParams.sizes, sliceParams.strides);
})
.Default([](ShapedType) -> Operation * {
@@ -793,8 +793,8 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
// `tiledOperands`.
Value outputTensor = operands[opOperand.getOperandNumber()];
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
- Value inserted = builder.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource().getType(), results[resultIdx],
+ Value inserted = tensor::InsertSliceOp::create(
+ builder, loc, sliceOp.getSource().getType(), results[resultIdx],
sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
sliceOp.getStrides(), sliceOp.getStaticOffsets(),
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
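
generateParallelLoopNest above peels off the leading run of parallel dimensions into one scf.parallel (or a guarded scf.if for per-processor distribution) and recurses on the remainder, appending the new induction variables to ivStorage as it goes. A toy sequential recursion, with plain for-loops standing in for the parallel ops, just to show the peel-and-recurse shape:

#include <cstddef>
#include <cstdio>
#include <vector>

// Peel one dimension, then recurse on the rest; the body runs once per
// full induction-variable tuple, as in the ivStorage recursion above.
static void generateLoopNest(const std::vector<int> &ubs, std::size_t from,
                             std::vector<int> &ivs) {
  if (from == ubs.size()) { // every dimension processed: emit the body
    std::printf("body at (");
    for (std::size_t i = 0; i < ivs.size(); ++i)
      std::printf("%s%d", i ? ", " : "", ivs[i]);
    std::printf(")\n");
    return;
  }
  for (int iv = 0; iv < ubs[from]; ++iv) { // one "parallel op" per dimension
    ivs.push_back(iv);
    generateLoopNest(ubs, from + 1, ivs);
    ivs.pop_back();
  }
}

int main() {
  std::vector<int> ivs;
  generateLoopNest({2, 3}, 0, ivs);
  return 0;
}
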
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index c5643f6..dfa2e4e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -85,11 +85,11 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
// TODO: support more types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](MemRefType t) {
- return builder.create<memref::AllocaOp>(getLoc(), t);
+ return memref::AllocaOp::create(builder, getLoc(), t);
})
.Default([&](Type t) {
- return builder.create<arith::ConstantOp>(getLoc(), t,
- builder.getZeroAttr(t));
+ return arith::ConstantOp::create(builder, getLoc(), t,
+ builder.getZeroAttr(t));
});
}
@@ -135,7 +135,7 @@ DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
for (Attribute usedIndex : usedIndices) {
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
- auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+ auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(), elemType});
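
getDefaultValue above dispatches on the slot's element type: a memref-typed slot gets a fresh alloca, anything else a zero constant. A stand-in for that TypeSwitch using std::variant (toy types, no LLVM dependency):

#include <iostream>
#include <type_traits>
#include <variant>

struct MemRefType {};
struct IntType {};
using Type = std::variant<MemRefType, IntType>;

static const char *defaultValueFor(const Type &t) {
  return std::visit(
      [](auto &&ty) -> const char * {
        using T = std::decay_t<decltype(ty)>;
        if constexpr (std::is_same_v<T, MemRefType>)
          return "memref.alloca";    // the .Case<MemRefType> branch
        else
          return "arith.constant 0"; // the .Default branch
      },
      t);
}

int main() {
  std::cout << defaultValueFor(MemRefType{}) << "\n"; // memref.alloca
  std::cout << defaultValueFor(IntType{}) << "\n";    // arith.constant 0
  return 0;
}
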
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 51c8136..74b968c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -213,9 +213,9 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
- auto newAlloc = rewriter.create<AllocLikeOp>(
- alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
- alloc.getAlignmentAttr());
+ auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
+ dynamicSizes, alloc.getSymbolOperands(),
+ alloc.getAlignmentAttr());
// Insert a cast so we have the same type as the old alloc.
rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
return success();
@@ -797,7 +797,7 @@ void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
int64_t index) {
auto loc = result.location;
- Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
+ Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
build(builder, result, source, indexValue);
}
@@ -1044,9 +1044,9 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
rewriter.setInsertionPointAfter(reshape);
Location loc = dim.getLoc();
Value load =
- rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
+ LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
if (load.getType() != dim.getType())
- load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+ load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
rewriter.replaceOp(dim, load);
return success();
}
@@ -1319,8 +1319,9 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
assert(isa<Attribute>(maybeConstant) &&
"The constified value should be either unchanged (i.e., == result) "
"or a constant");
- Value constantVal = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
+ Value constantVal = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
@@ -2548,8 +2549,9 @@ public:
rewriter.modifyOpInPlace(
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
- Value newOp = rewriter.create<CollapseShapeOp>(
- op->getLoc(), cast.getSource(), op.getReassociationIndices());
+ Value newOp =
+ CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
+ op.getReassociationIndices());
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
return success();
@@ -3006,15 +3008,15 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
Value offset =
op.isDynamicOffset(idx)
? op.getDynamicOffset(idx)
- : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
+ : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
Value size =
op.isDynamicSize(idx)
? op.getDynamicSize(idx)
- : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
+ : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
Value stride =
op.isDynamicStride(idx)
? op.getDynamicStride(idx)
- : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
+ : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
res.emplace_back(Range{offset, size, stride});
}
return res;
@@ -3173,8 +3175,8 @@ public:
if (!resultType)
return failure();
- Value newSubView = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), resultType, castOp.getSource(),
+ Value newSubView = SubViewOp::create(
+ rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
subViewOp.getStaticStrides());
@@ -3495,9 +3497,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
return failure();
// Create new ViewOp.
- auto newViewOp = rewriter.create<ViewOp>(
- viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
- viewOp.getByteShift(), newOperands);
+ auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
+ viewOp.getOperand(0), viewOp.getByteShift(),
+ newOperands);
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
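
getOrCreateRanges shows a pattern repeated throughout this commit: each offset, size, and stride is either an existing dynamic SSA value or a constant materialized on the spot. Reduced to its arithmetic core (hypothetical helper, plain integers standing in for Values):

#include <cstdint>
#include <cstdio>
#include <optional>

// A dynamic value wins if present; otherwise the static one is
// materialized (an arith.constant_index in the pass).
static int64_t valueOrConstant(std::optional<int64_t> dynamic,
                               int64_t staticValue) {
  return dynamic ? *dynamic : staticValue;
}

int main() {
  std::printf("%lld\n", (long long)valueOrConstant(std::nullopt, 8)); // 8
  std::printf("%lld\n", (long long)valueOrConstant(42, 8));           // 42
  return 0;
}
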
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 0c03670..95eb2a9 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -155,9 +155,10 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
Type resultType = alloca.getResult().getType();
OpBuilder builder(rewriter.getContext());
// TODO: Add a better builder for this.
- globalOp = builder.create<memref::GlobalOp>(
- loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
- TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+ globalOp = memref::GlobalOp::create(
+ builder, loc, StringAttr::get(ctx, "alloca"),
+ StringAttr::get(ctx, "private"), TypeAttr::get(resultType),
+ Attribute{}, UnitAttr{}, IntegerAttr{});
symbolTable.insert(globalOp);
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
index c433415..75cc39e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
@@ -22,11 +22,11 @@ struct DefaultAllocationInterface
DefaultAllocationInterface, memref::AllocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value alloc) {
- return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+ return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
.getOperation();
}
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
- return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
+ return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc)
.getResult();
}
static ::mlir::HoistingKind getHoistingKind() {
@@ -35,8 +35,9 @@ struct DefaultAllocationInterface
static ::std::optional<::mlir::Operation *>
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
Operation *definingOp = alloc.getDefiningOp();
- return builder.create<memref::AllocaOp>(
- definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
+ return memref::AllocaOp::create(
+ builder, definingOp->getLoc(),
+ cast<MemRefType>(definingOp->getResultTypes()[0]),
definingOp->getOperands(), definingOp->getAttrs());
}
};
@@ -52,7 +53,7 @@ struct DefaultReallocationInterface
DefaultAllocationInterface, memref::ReallocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value realloc) {
- return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
+ return memref::DeallocOp::create(builder, realloc.getLoc(), realloc)
.getOperation();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 7c777e8..cce80db 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
- // We only support static sizes.
- if (isa<Value>(opSize)) {
- return failure();
- }
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
@@ -124,8 +120,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}
AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
- Value result = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), map, affineApplyOperands);
+ Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map,
+ affineApplyOperands);
offsets.push_back(result);
}
}
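
Dropping the static-size guard lets ComposeSubViewOpPattern fold chained subviews even when the inner sizes are dynamic; offsets still compose per dimension through the source stride, via the affine.apply built above when any operand is dynamic. A numeric check of that per-dimension composition, under the reading that the folded offset is srcOffset + srcStride * opOffset:

#include <cassert>

int main() {
  int srcOff = 4, srcStride = 2; // outer subview: offset 4, stride 2
  int opOff = 3;                 // inner subview offset, in the outer's coords
  int composed = srcOff + srcStride * opOff;
  assert(composed == 10); // element 3 of the strided view is base element 10
  return composed == 10 ? 0 : 1;
}
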
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index ec2bc95..556ea1a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -99,7 +99,7 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
- return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
+ return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}
/// When writing a subbyte size, masked bitwise operations are used to only
@@ -112,14 +112,14 @@ static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
auto dstIntegerType = builder.getIntegerType(dstBits);
auto maskRightAlignedAttr =
builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
- Value maskRightAligned = builder.create<arith::ConstantOp>(
- loc, dstIntegerType, maskRightAlignedAttr);
+ Value maskRightAligned = arith::ConstantOp::create(
+ builder, loc, dstIntegerType, maskRightAlignedAttr);
Value writeMaskInverse =
- builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
Value flipVal =
- builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
- return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+ arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
+ return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}
/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
@@ -141,7 +141,7 @@ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
const SmallVector<OpFoldResult> &indices,
Value memref) {
auto stridedMetadata =
- builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+ memref::ExtractStridedMetadataOp::create(builder, loc, memref);
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
@@ -298,16 +298,16 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Special case 0-rank memref loads.
Value bitsLoad;
if (convertedType.getRank() == 0) {
- bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
- ValueRange{});
+ bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
+ ValueRange{});
} else {
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
OpFoldResult linearizedIndices = getLinearizedSrcIndices(
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
- Value newLoad = rewriter.create<memref::LoadOp>(
- loc, adaptor.getMemref(),
+ Value newLoad = memref::LoadOp::create(
+ rewriter, loc, adaptor.getMemref(),
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
dstBits));
@@ -315,7 +315,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Note, currently only the big-endian is supported.
Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
srcBits, dstBits, rewriter);
- bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
+ bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
}
// Get the corresponding bits. If the arith computation bitwidth equals
@@ -331,17 +331,17 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
: IntegerType::get(rewriter.getContext(),
resultTy.getIntOrFloatBitWidth());
if (conversionTy == convertedElementType) {
- auto mask = rewriter.create<arith::ConstantOp>(
- loc, convertedElementType,
+ auto mask = arith::ConstantOp::create(
+ rewriter, loc, convertedElementType,
rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
- result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+ result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
} else {
- result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
+ result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
}
if (conversionTy != resultTy) {
- result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+ result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
}
rewriter.replaceOp(op, result);
@@ -428,20 +428,20 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Pad the input value with 0s on the left.
Value input = adaptor.getValue();
if (!input.getType().isInteger()) {
- input = rewriter.create<arith::BitcastOp>(
- loc,
+ input = arith::BitcastOp::create(
+ rewriter, loc,
IntegerType::get(rewriter.getContext(),
input.getType().getIntOrFloatBitWidth()),
input);
}
Value extendedInput =
- rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
+ arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
- extendedInput, adaptor.getMemref(),
- ValueRange{});
+ memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
+ extendedInput, adaptor.getMemref(),
+ ValueRange{});
rewriter.eraseOp(op);
return success();
}
@@ -456,16 +456,14 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
dstBits, bitwidthOffset, rewriter);
// Align the value to write with the destination bits
Value alignedVal =
- rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+ arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
// Clear destination bits
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
- writeMask, adaptor.getMemref(),
- storeIndices);
+ memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(), storeIndices);
// Write srcs bits to destination
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
- alignedVal, adaptor.getMemref(),
- storeIndices);
+ memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(), storeIndices);
rewriter.eraseOp(op);
return success();
}
@@ -525,8 +523,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
// Transform the offsets, sizes and strides according to the emulation.
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, subViewOp.getViewSource());
+ auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, subViewOp.getViewSource());
OpFoldResult linearizedIndices;
auto strides = stridedMetadata.getConstifiedMixedStrides();
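
The store path in this file packs sub-byte values into wider words with two atomic read-modify-write ops: an andi with the inverted write mask to clear the destination bits, then an ori with the shifted value. The same mask arithmetic on a plain byte (ordinary integers here, with the bit-offset computation simplified away):

#include <cassert>
#include <cstdint>

int main() {
  const unsigned srcBits = 4, bitOffset = 4; // i4 stored in the high nibble
  uint8_t word = 0xAB;  // packed storage word (high nibble A, low nibble B)
  uint8_t value = 0x5;  // sub-byte value to store

  uint8_t maskRightAligned = (1u << srcBits) - 1;           // 0x0F
  uint8_t writeMaskInverse = maskRightAligned << bitOffset; // shl: 0xF0
  uint8_t writeMask = writeMaskInverse ^ 0xFF; // xor with all-ones: 0x0F
  uint8_t aligned = value << bitOffset;        // shl: 0x50

  word &= writeMask; // atomic_rmw andi: clear destination bits -> 0x0B
  word |= aligned;   // atomic_rmw ori: write the source bits   -> 0x5B
  assert(word == 0x5B);
  return word == 0x5B ? 0 : 1;
}
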
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index e6e4c3b0..17a148c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -48,15 +48,15 @@ public:
Value size;
// Load dynamic sizes from the shape input, use constants for static dims.
if (op.getType().isDynamicDim(i)) {
- Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
- size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
+ Value index = arith::ConstantIndexOp::create(rewriter, loc, i);
+ size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
if (!isa<IndexType>(size.getType()))
- size = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), size);
+ size = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), size);
sizes[i] = size;
} else {
auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
- size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
+ size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
sizes[i] = sizeAttr;
}
if (stride)
@@ -66,10 +66,11 @@ public:
if (i > 0) {
if (stride) {
- stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+ stride = arith::MulIOp::create(rewriter, loc, stride, size);
} else if (op.getType().isDynamicDim(i)) {
- stride = rewriter.create<arith::MulIOp>(
- loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
+ stride = arith::MulIOp::create(
+ rewriter, loc,
+ arith::ConstantIndexOp::create(rewriter, loc, staticStride),
size);
} else {
staticStride *= op.getType().getDimSize(i);
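
The reshape expansion above materializes the row-major strides of the target memref as a running product of sizes (the rewriter builds them incrementally with arith.muli; loop direction aside, the result is the usual suffix product). Checked numerically for a 2x3x4 shape:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {2, 3, 4};
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int i = (int)sizes.size() - 1; i >= 0; --i) {
    strides[i] = running; // stride[i] = product of the sizes after i
    running *= sizes[i];
  }
  assert((strides == std::vector<int64_t>{12, 4, 1}));
  return strides[0] == 12 ? 0 : 1;
}
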
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
index 7475d44..01d3262 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
@@ -73,7 +73,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
if (ShapedType::isDynamic(inputSize)) {
Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
rewriter.getIndexAttr(0));
- currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
+ currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
.getResult();
}
@@ -88,10 +88,10 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// the old buffer is smaller than the requested size.
Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
- Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- lhs, rhs);
- auto ifOp = rewriter.create<scf::IfOp>(
- loc, cond,
+ Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ lhs, rhs);
+ auto ifOp = scf::IfOp::create(
+ rewriter, loc, cond,
[&](OpBuilder &builder, Location loc) {
// Allocate the new buffer. If it is a dynamic memref we need to pass
// an additional operand for the size at runtime, otherwise the static
@@ -100,25 +100,26 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
if (op.getDynamicResultSize())
dynamicSizeOperands.push_back(op.getDynamicResultSize());
- Value newAlloc = builder.create<memref::AllocOp>(
- loc, op.getResult().getType(), dynamicSizeOperands,
+ Value newAlloc = memref::AllocOp::create(
+ builder, loc, op.getResult().getType(), dynamicSizeOperands,
op.getAlignmentAttr());
// Take a subview of the new (bigger) buffer such that we can copy the
// old values over (the copy operation requires both operands to have
// the same shape).
- Value subview = builder.create<memref::SubViewOp>(
- loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
+ Value subview = memref::SubViewOp::create(
+ builder, loc, newAlloc,
+ ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
ArrayRef<OpFoldResult>{currSize},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
- builder.create<memref::CopyOp>(loc, op.getSource(), subview);
+ memref::CopyOp::create(builder, loc, op.getSource(), subview);
// Insert the deallocation of the old buffer only if requested
// (enabled by default).
if (emitDeallocs)
- builder.create<memref::DeallocOp>(loc, op.getSource());
+ memref::DeallocOp::create(builder, loc, op.getSource());
- builder.create<scf::YieldOp>(loc, newAlloc);
+ scf::YieldOp::create(builder, loc, newAlloc);
},
[&](OpBuilder &builder, Location loc) {
// We need to reinterpret-cast here because either the input or output
@@ -126,11 +127,12 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// dynamic or vice-versa. If both are static and the original buffer
// is already bigger than the requested size, the cast represents a
// subview operation.
- Value casted = builder.create<memref::ReinterpretCastOp>(
- loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
- rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
+ Value casted = memref::ReinterpretCastOp::create(
+ builder, loc, cast<MemRefType>(op.getResult().getType()),
+ op.getSource(), rewriter.getIndexAttr(0),
+ ArrayRef<OpFoldResult>{targetSize},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
- builder.create<scf::YieldOp>(loc, casted);
+ scf::YieldOp::create(builder, loc, casted);
});
rewriter.replaceOp(op, ifOp.getResult(0));
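
ExpandReallocOpPattern lowers memref.realloc into a size comparison plus an scf.if: grow, copy, and optionally deallocate when the current buffer is too small, otherwise reinterpret_cast the existing one. The same control flow over plain heap buffers, as a rough sketch rather than the pass's semantics in full:

#include <cstddef>
#include <cstdlib>
#include <cstring>

static void *expandedRealloc(void *src, std::size_t currSize,
                             std::size_t targetSize,
                             bool emitDeallocs = true) {
  if (currSize < targetSize) {            // scf.if "then": grow and copy
    void *newAlloc = std::malloc(targetSize);
    std::memcpy(newAlloc, src, currSize); // the memref.copy via the subview
    if (emitDeallocs)
      std::free(src);                     // dealloc of the old buffer
    return newAlloc;
  }
  return src; // "else" branch: only a reinterpret_cast, no new allocation
}

int main() {
  char *buf = static_cast<char *>(std::malloc(4));
  std::memcpy(buf, "abc", 4);
  buf = static_cast<char *>(expandedRealloc(buf, 4, 16));
  std::free(buf);
  return 0;
}
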
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 2ba798f..9771bd2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -66,7 +66,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+ memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
@@ -577,7 +577,7 @@ static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+ memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
// Collect statically known information.
auto [strides, offset] = sourceType.getStridesAndOffset();
@@ -828,14 +828,14 @@ public:
if (allocLikeOp.getType() == baseBufferType)
results.push_back(allocLikeOp);
else
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, allocLikeOp, offset,
+ results.push_back(memref::ReinterpretCastOp::create(
+ rewriter, loc, baseBufferType, allocLikeOp, offset,
/*sizes=*/ArrayRef<int64_t>(),
/*strides=*/ArrayRef<int64_t>()));
}
// Offset.
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset));
for (OpFoldResult size : sizes)
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));
@@ -900,19 +900,19 @@ public:
if (getGlobalOp.getType() == baseBufferType)
results.push_back(getGlobalOp);
else
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, getGlobalOp, offset,
+ results.push_back(memref::ReinterpretCastOp::create(
+ rewriter, loc, baseBufferType, getGlobalOp, offset,
/*sizes=*/ArrayRef<int64_t>(),
/*strides=*/ArrayRef<int64_t>()));
// Offset.
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset));
for (auto size : sizes)
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, size));
for (auto stride : strides)
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride));
rewriter.replaceOp(op, results);
return success();
@@ -1008,9 +1008,8 @@ class ExtractStridedMetadataOpReinterpretCastFolder
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, reinterpretCastOp.getSource());
+ auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, reinterpretCastOp.getSource());
// Register the base_buffer.
results[0] = newExtractStridedMetadata.getBaseBuffer();
@@ -1082,9 +1081,8 @@ class ExtractStridedMetadataOpCastFolder
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc,
- castOp.getSource());
+ auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, castOp.getSource());
// Register the base_buffer.
results[0] = newExtractStridedMetadata.getBaseBuffer();
@@ -1142,9 +1140,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder
auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
if (!memSpaceCastOp)
return failure();
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, memSpaceCastOp.getSource());
+ auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, memSpaceCastOp.getSource());
SmallVector<Value> results(newExtractStridedMetadata.getResults());
// As with most other strided metadata rewrite patterns, don't introduce
    // a use of the base pointer where none existed. This needs to happen here,
@@ -1158,8 +1155,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder
MemRefType::Builder newTypeBuilder(baseBufferType);
newTypeBuilder.setMemorySpace(
memSpaceCastOp.getResult().getType().getMemorySpace());
- results[0] = rewriter.create<memref::MemorySpaceCastOp>(
- loc, Type{newTypeBuilder}, baseBuffer);
+ results[0] = memref::MemorySpaceCastOp::create(
+ rewriter, loc, Type{newTypeBuilder}, baseBuffer);
} else {
results[0] = nullptr;
}
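For reference when reading the rewritten call sites above: `extract_strided_metadata` decomposes a memref into its base allocation plus layout, exposed as named results. A minimal consumer sketch (`source` is a placeholder memref value):

  auto metadata =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
  Value base = metadata.getBaseBuffer();      // rank-0 view of the allocation
  Value offset = metadata.getOffset();        // index-typed element offset
  ValueRange sizes = metadata.getSizes();     // one index per dimension
  ValueRange strides = metadata.getStrides(); // one index per dimension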
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 2f5c943..0946da8e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -42,8 +42,8 @@ static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
memref::LoadOp loadOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = loadOp.getLoc();
- return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
- loadOp.getNontemporal());
+ return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
+ loadOp.getNontemporal());
}
// Matches getViewSizeForEachDim specs for LoadOp.
@@ -72,9 +72,8 @@ static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
memref::StoreOp storeOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = storeOp.getLoc();
- return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
- srcMemRef, indices,
- storeOp.getNontemporal());
+ return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
+ srcMemRef, indices, storeOp.getNontemporal());
}
// Matches getViewSizeForEachDim specs for StoreOp.
@@ -104,8 +103,8 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = ldMatrixOp.getLoc();
- return rewriter.create<nvgpu::LdMatrixOp>(
- loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
+ return nvgpu::LdMatrixOp::create(
+ rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
}
@@ -132,8 +131,8 @@ rebuildTransferReadOp(RewriterBase &rewriter,
vector::TransferReadOp transferReadOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = transferReadOp.getLoc();
- return rewriter.create<vector::TransferReadOp>(
- loc, transferReadOp.getResult().getType(), srcMemRef, indices,
+ return vector::TransferReadOp::create(
+ rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices,
transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
}
@@ -150,8 +149,8 @@ rebuildTransferWriteOp(RewriterBase &rewriter,
vector::TransferWriteOp transferWriteOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = transferWriteOp.getLoc();
- return rewriter.create<vector::TransferWriteOp>(
- loc, transferWriteOp.getValue(), srcMemRef, indices,
+ return vector::TransferWriteOp::create(
+ rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices,
transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
transferWriteOp.getInBoundsAttr());
}
@@ -182,9 +181,8 @@ static SmallVector<OpFoldResult>
getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
LoadStoreLikeOp loadStoreLikeOp) {
Location loc = loadStoreLikeOp.getLoc();
- auto extractStridedMetadataOp =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, getSrcMemRef(loadStoreLikeOp));
+ auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, getSrcMemRef(loadStoreLikeOp));
SmallVector<OpFoldResult> srcSizes =
extractStridedMetadataOp.getConstifiedMixedSizes();
SmallVector<OpFoldResult> indices =
@@ -267,12 +265,12 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
// apply them properly to the input indices.
// Therefore the strides multipliers are simply ones.
auto subview =
- rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
- /*offsets=*/indices,
- /*sizes=*/sizes, /*strides=*/ones);
+ memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef,
+ /*offsets=*/indices,
+ /*sizes=*/sizes, /*strides=*/ones);
// Rewrite the load/store with the subview as the base pointer.
SmallVector<Value> zeros(loadStoreRank,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
rewriter, loadStoreLikeOp, subview.getResult(), zeros);
rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
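Sketched on a single load, the rewrite this pattern performs (reusing the `srcMemRef`, `indices`, `sizes`, `ones`, and `loadStoreRank` names from the hunk above): the address arithmetic is peeled into a unit-stride subview, and the rebuilt access reads the subview's origin:

  Value subview = memref::SubViewOp::create(rewriter, loc, srcMemRef,
                                            /*offsets=*/indices,
                                            /*sizes=*/sizes, /*strides=*/ones);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> zeros(loadStoreRank, zero);
  Value reloaded = memref::LoadOp::create(rewriter, loc, subview, zeros);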
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 76f7788..42be847 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -40,8 +40,8 @@ using namespace mlir;
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
OpFoldResult in) {
if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
- return rewriter.create<arith::ConstantIndexOp>(
- loc, cast<IntegerAttr>(offsetAttr).getInt());
+ return arith::ConstantIndexOp::create(
+ rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
}
return cast<Value>(in);
}
@@ -60,7 +60,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
}
memref::ExtractStridedMetadataOp stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
OpFoldResult linearizedIndices;
@@ -74,8 +74,8 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
getAsOpFoldResult(indices));
return std::make_pair(
- rewriter.create<memref::ReinterpretCastOp>(
- loc, source,
+ memref::ReinterpretCastOp::create(
+ rewriter, loc, source,
/* offset = */ linearizedInfo.linearizedOffset,
/* shapes = */
ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
@@ -111,7 +111,7 @@ template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
PatternRewriter &rewriter) {
memref::ExtractStridedMetadataOp stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
oper, cast<MemRefType>(oper.getType()), newOper,
/*offset=*/rewriter.getIndexAttr(0),
@@ -125,63 +125,68 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
Location loc = op->getLoc();
llvm::TypeSwitch<Operation *>(op.getOperation())
.template Case<memref::AllocOp>([&](auto oper) {
- auto newAlloc = rewriter.create<memref::AllocOp>(
- loc, cast<MemRefType>(flatMemref.getType()),
+ auto newAlloc = memref::AllocOp::create(
+ rewriter, loc, cast<MemRefType>(flatMemref.getType()),
oper.getAlignmentAttr());
castAllocResult(oper, newAlloc, loc, rewriter);
})
.template Case<memref::AllocaOp>([&](auto oper) {
- auto newAlloca = rewriter.create<memref::AllocaOp>(
- loc, cast<MemRefType>(flatMemref.getType()),
+ auto newAlloca = memref::AllocaOp::create(
+ rewriter, loc, cast<MemRefType>(flatMemref.getType()),
oper.getAlignmentAttr());
castAllocResult(oper, newAlloca, loc, rewriter);
})
.template Case<memref::LoadOp>([&](auto op) {
- auto newLoad = rewriter.create<memref::LoadOp>(
- loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+ auto newLoad =
+ memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
+ flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
})
.template Case<memref::StoreOp>([&](auto op) {
- auto newStore = rewriter.create<memref::StoreOp>(
- loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+ auto newStore =
+ memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
+ flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
})
.template Case<vector::LoadOp>([&](auto op) {
- auto newLoad = rewriter.create<vector::LoadOp>(
- loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+ auto newLoad =
+ vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
+ flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
})
.template Case<vector::StoreOp>([&](auto op) {
- auto newStore = rewriter.create<vector::StoreOp>(
- loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+ auto newStore =
+ vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
+ flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
})
.template Case<vector::MaskedLoadOp>([&](auto op) {
- auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
- loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
- op.getPassThru());
+ auto newMaskedLoad = vector::MaskedLoadOp::create(
+ rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
+ op.getMask(), op.getPassThru());
newMaskedLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedLoad.getResult());
})
.template Case<vector::MaskedStoreOp>([&](auto op) {
- auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
- loc, flatMemref, ValueRange{offset}, op.getMask(),
+ auto newMaskedStore = vector::MaskedStoreOp::create(
+ rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
op.getValueToStore());
newMaskedStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedStore);
})
.template Case<vector::TransferReadOp>([&](auto op) {
- auto newTransferRead = rewriter.create<vector::TransferReadOp>(
- loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
+ auto newTransferRead = vector::TransferReadOp::create(
+ rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
+ op.getPadding());
rewriter.replaceOp(op, newTransferRead.getResult());
})
.template Case<vector::TransferWriteOp>([&](auto op) {
- auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
- loc, op.getVector(), flatMemref, ValueRange{offset});
+ auto newTransferWrite = vector::TransferWriteOp::create(
+ rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);
})
.Default([&](auto op) {
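The dispatch above is the standard `llvm::TypeSwitch` idiom (from llvm/ADT/TypeSwitch.h); its shape, reduced to a self-contained sketch with placeholder handlers:

  llvm::TypeSwitch<Operation *>(op)
      .Case<memref::LoadOp>([&](memref::LoadOp load) {
        // One handler per concrete op class; `load` is already cast.
      })
      .Case<memref::StoreOp>([&](memref::StoreOp store) {
        // ...
      })
      .Default([&](Operation *) {
        // Any op not matched by a Case lands here.
      });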
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 89be188..24da447 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -44,97 +44,6 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = memref.expand_shape %0 [[0, 1], [2]]
-/// : memref<12x42xf32> into memref<2x6x42xf32>
-/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
-///
-/// could be folded into
-///
-/// %2 = load %0[6 * i1 + i2, %i3] :
-/// memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
- Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp, ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
- SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
-
- // Traverse all reassociation groups to determine the appropriate indices
- // corresponding to each one of them post op folding.
- for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
- assert(!group.empty() && "association indices groups cannot be empty");
- int64_t groupSize = group.size();
- if (groupSize == 1) {
- sourceIndices.push_back(indices[group[0]]);
- continue;
- }
- SmallVector<OpFoldResult> groupBasis =
- llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
- SmallVector<Value> groupIndices =
- llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
- Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
- loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
- sourceIndices.push_back(collapsedIndex);
- }
- return success();
-}
-
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
-///
-/// %0 = ... : memref<2x6x42xf32>
-/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
-/// : memref<2x6x42xf32> into memref<12x42xf32>
-/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
-///
-/// could be folded into
-///
-/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
-/// memref<2x6x42xf32>
-static LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
- memref::CollapseShapeOp collapseShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // Note: collapse_shape requires a strided memref, we can do this.
- auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, collapseShapeOp.getSrc());
- SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
- for (auto [index, group] :
- llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
- assert(!group.empty() && "association indices groups cannot be empty");
- int64_t groupSize = group.size();
-
- if (groupSize == 1) {
- sourceIndices.push_back(index);
- continue;
- }
-
- SmallVector<OpFoldResult> basis =
- llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
- auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, index, basis, /*hasOuterBound=*/true);
- llvm::append_range(sourceIndices, delinearize.getResults());
- }
- if (collapseShapeOp.getReassociationIndices().empty()) {
- auto zeroAffineMap = rewriter.getConstantAffineMap(0);
- int64_t srcRank =
- cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
- for (int64_t i = 0; i < srcRank; i++) {
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
- }
- }
- return success();
-}
-
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
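The two helpers deleted above are not dropped; they reappear in MemRefUtils.cpp later in this diff. Their core primitives, sketched with placeholder index values `i`, `j` and a `SmallVector<OpFoldResult> basis` of group sizes:

  // expand_shape folding: the group's indices linearize into one index.
  Value linear = affine::AffineLinearizeIndexOp::create(
      rewriter, loc, ValueRange{i, j}, basis, /*disjoint=*/false);
  // collapse_shape folding: one index delinearizes along the same basis.
  auto split = affine::AffineDelinearizeIndexOp::create(
      rewriter, loc, linear, basis, /*hasOuterBound=*/true);
  ValueRange groupIndices = split.getResults();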
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 35c661e..d5e2b97 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -51,14 +51,13 @@ FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
// Create a new memref::AllocaOp.
Value newAllocaOp =
- b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());
+ AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType());
// Create a memref::SubViewOp.
SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
- return b
- .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
- strides)
+ return SubViewOp::create(b, loc, newAllocaOp, offsets,
+ allocaOp.getMixedSizes(), strides)
.getResult();
}
@@ -71,11 +70,11 @@ propagateSubViewOp(RewriterBase &rewriter,
MemRefType newResultType = SubViewOp::inferRankReducedResultType(
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
op.getMixedSizes(), op.getMixedStrides());
- Value newSubview = rewriter.create<SubViewOp>(
- op.getLoc(), newResultType, conversionOp.getOperand(0),
+ Value newSubview = SubViewOp::create(
+ rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
- auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
- op.getLoc(), op.getType(), newSubview);
+ auto newConversionOp = UnrealizedConversionCastOp::create(
+ rewriter, op.getLoc(), op.getType(), newSubview);
rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
return newConversionOp;
}
@@ -106,8 +105,8 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
for (const auto &it :
llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
- unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
- to->getLoc(), std::get<0>(it.value()).getType(),
+ unrealizedConversions.push_back(UnrealizedConversionCastOp::create(
+ rewriter, to->getLoc(), std::get<0>(it.value()).getType(),
std::get<1>(it.value())));
rewriter.replaceAllUsesWith(from->getResult(it.index()),
unrealizedConversions.back()->getResult(0));
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 0a84962..5d3cec4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -63,9 +63,10 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
- Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
- subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
+ Value newSubview = memref::SubViewOp::create(
+ rewriter, subviewUse->getLoc(), newType, val,
+ subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+ subviewUse.getMixedStrides());
// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -177,8 +178,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
Location loc = allocOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(allocOp);
- auto mbAlloc = rewriter.create<memref::AllocOp>(
- loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
+ auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType,
+ ValueRange{}, allocOp->getAttrs());
LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
// 3. Within the loop, build the modular leading index (i.e. each loop
@@ -211,8 +212,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
// Strides is [1, 1 ... 1 ].
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
originalShape, mbMemRefType, offsets, sizes, strides);
- Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
- offsets, sizes, strides);
+ Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc,
+ offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
  // 5. Due to the recursive nature of replaceUsesAndPropagateType, we need to
@@ -224,7 +225,7 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(deallocOp);
auto newDeallocOp =
- rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
+ memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc);
(void)newDeallocOp;
LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
rewriter.eraseOp(deallocOp);
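For orientation, the shape-level effect these builder calls implement (factor and shapes illustrative only): multi-buffering by 3 prepends a leading buffer dimension, and each iteration slices its buffer back down with a rank-reducing subview whose type is inferred first:

  //   %alloc   : memref<4x128xf32>   -->   %mbAlloc : memref<3x4x128xf32>
  //   %slice = memref.subview %mbAlloc[%idx, 0, 0] [1, 4, 128] [1, 1, 1]
  //     : memref<3x4x128xf32>
  //       to memref<4x128xf32, strided<[128, 1], offset: ?>>
  MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
      originalShape, mbMemRefType, offsets, sizes, strides);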
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 4ec0432..fa7991e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -276,8 +276,8 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
if (!callOp)
continue;
Operation *newCallOp =
- builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
- resultTypes, userOp->getOperands());
+ func::CallOp::create(builder, userOp->getLoc(), callOp.getCalleeAttr(),
+ resultTypes, userOp->getOperands());
bool replacingMemRefUsesFailed = false;
bool returnTypeChanged = false;
for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index 46f9d64e..d65825b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -115,10 +115,12 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
// Update the type.
newRes.setType(reifiedTy);
if (isa<RankedTensorType>(reifiedTy)) {
- newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+ newResults.push_back(
+ tensor::CastOp::create(rewriter, loc, oldTy, newRes));
} else {
assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
- newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+ newResults.push_back(
+ memref::CastOp::create(rewriter, loc, oldTy, newRes));
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 89a3895..6a81a15 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -69,7 +69,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
Location loc = dimOp->getLoc();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
dimOp, resultShape,
- rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
+ arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult());
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index d231516..d3a77c0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -40,19 +40,18 @@ struct AssumeAlignmentOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
- Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
- loc, assumeOp.getMemref());
- Value rest = builder.create<arith::RemUIOp>(
- loc, ptr,
- builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
- Value isAligned = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, rest,
- builder.create<arith::ConstantIndexOp>(loc, 0));
- builder.create<cf::AssertOp>(
- loc, isAligned,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "memref is not aligned to " +
- std::to_string(assumeOp.getAlignment())));
+ Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
+ assumeOp.getMemref());
+ Value rest = arith::RemUIOp::create(
+ builder, loc, ptr,
+ arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment()));
+ Value isAligned =
+ arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
+ arith::ConstantIndexOp::create(builder, loc, 0));
+ cf::AssertOp::create(builder, loc, isAligned,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "memref is not aligned to " +
+ std::to_string(assumeOp.getAlignment())));
}
};
@@ -71,15 +70,14 @@ struct CastOpInterface
if (isa<UnrankedMemRefType>(srcType)) {
// Check rank.
- Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
+ Value srcRank = RankOp::create(builder, loc, castOp.getSource());
Value resultRank =
- builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
- Value isSameRank = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcRank, resultRank);
- builder.create<cf::AssertOp>(
- loc, isSameRank,
- RuntimeVerifiableOpInterface::generateErrorMessage(op,
- "rank mismatch"));
+ arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
+ Value isSameRank = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
+ cf::AssertOp::create(builder, loc, isSameRank,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "rank mismatch"));
}
// Get source offset and strides. We do not have an op to get offsets and
@@ -95,8 +93,9 @@ struct CastOpInterface
MemRefType::get(dynamicShape, resultType.getElementType(),
stridedLayout, resultType.getMemorySpace());
Value helperCast =
- builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
- auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
+ CastOp::create(builder, loc, dynStridesType, castOp.getSource());
+ auto metadataOp =
+ ExtractStridedMetadataOp::create(builder, loc, helperCast);
// Check dimension sizes.
for (const auto &it : llvm::enumerate(resultType.getShape())) {
@@ -110,13 +109,13 @@ struct CastOpInterface
continue;
Value srcDimSz =
- builder.create<DimOp>(loc, castOp.getSource(), it.index());
+ DimOp::create(builder, loc, castOp.getSource(), it.index());
Value resultDimSz =
- builder.create<arith::ConstantIndexOp>(loc, it.value());
- Value isSameSz = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
- builder.create<cf::AssertOp>(
- loc, isSameSz,
+ arith::ConstantIndexOp::create(builder, loc, it.value());
+ Value isSameSz = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
+ cf::AssertOp::create(
+ builder, loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size mismatch of dim " + std::to_string(it.index())));
}
@@ -132,13 +131,12 @@ struct CastOpInterface
// Static/dynamic offset -> dynamic offset does not need verification.
Value srcOffset = metadataOp.getResult(1);
Value resultOffsetVal =
- builder.create<arith::ConstantIndexOp>(loc, resultOffset);
- Value isSameOffset = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
- builder.create<cf::AssertOp>(
- loc, isSameOffset,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset mismatch"));
+ arith::ConstantIndexOp::create(builder, loc, resultOffset);
+ Value isSameOffset = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
+ cf::AssertOp::create(builder, loc, isSameOffset,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "offset mismatch"));
}
// Check strides.
@@ -150,11 +148,11 @@ struct CastOpInterface
Value srcStride =
metadataOp.getResult(2 + resultType.getRank() + it.index());
Value resultStrideVal =
- builder.create<arith::ConstantIndexOp>(loc, it.value());
- Value isSameStride = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
- builder.create<cf::AssertOp>(
- loc, isSameStride,
+ arith::ConstantIndexOp::create(builder, loc, it.value());
+ Value isSameStride = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
+ cf::AssertOp::create(
+ builder, loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "stride mismatch of dim " + std::to_string(it.index())));
}
@@ -186,21 +184,19 @@ struct CopyOpInterface
auto getDimSize = [&](Value memRef, MemRefType type,
int64_t dim) -> Value {
return type.isDynamicDim(dim)
- ? builder.create<DimOp>(loc, memRef, dim).getResult()
- : builder
- .create<arith::ConstantIndexOp>(loc,
- type.getDimSize(dim))
+ ? DimOp::create(builder, loc, memRef, dim).getResult()
+ : arith::ConstantIndexOp::create(builder, loc,
+ type.getDimSize(dim))
.getResult();
};
Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
- Value sameDimSize = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
- builder.create<cf::AssertOp>(
- loc, sameDimSize,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "size of " + std::to_string(i) +
- "-th source/target dim does not match"));
+ Value sameDimSize = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
+ cf::AssertOp::create(builder, loc, sameDimSize,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "size of " + std::to_string(i) +
+ "-th source/target dim does not match"));
}
}
};
@@ -211,10 +207,11 @@ struct DimOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto dimOp = cast<DimOp>(op);
- Value rank = builder.create<RankOp>(loc, dimOp.getSource());
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- builder.create<cf::AssertOp>(
- loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
+ Value rank = RankOp::create(builder, loc, dimOp.getSource());
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ cf::AssertOp::create(
+ builder, loc,
+ generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "index is out of bounds"));
}
@@ -237,7 +234,7 @@ struct LoadStoreOpInterface
}
auto indices = loadStoreOp.getIndices();
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
@@ -247,10 +244,9 @@ struct LoadStoreOpInterface
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
: inBounds;
}
- builder.create<cf::AssertOp>(
- loc, assertCond,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "out-of-bounds access"));
+ cf::AssertOp::create(builder, loc, assertCond,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "out-of-bounds access"));
}
};
@@ -265,10 +261,10 @@ struct SubViewOpInterface
// For each dimension, assert that:
// 0 <= offset < dim_size
// 0 <= offset + (size - 1) * stride < dim_size
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
- builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
+ ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
@@ -281,21 +277,21 @@ struct SubViewOpInterface
Value dimSize = metadataOp.getSizes()[i];
Value offsetInBounds =
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- builder.create<cf::AssertOp>(
- loc, offsetInBounds,
+ cf::AssertOp::create(
+ builder, loc, offsetInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "offset " + std::to_string(i) + " is out-of-bounds"));
// Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
+ Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
- builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
+ arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
Value lastPos =
- builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
+ arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
- builder.create<cf::AssertOp>(
- loc, lastPosInBounds,
+ cf::AssertOp::create(
+ builder, loc, lastPosInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview runs out-of-bounds along dimension " +
std::to_string(i)));
@@ -315,7 +311,7 @@ struct ExpandShapeOpInterface
for (const auto &it :
llvm::enumerate(expandShapeOp.getReassociationIndices())) {
Value srcDimSz =
- builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
+ DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
int64_t groupSz = 1;
bool foundDynamicDim = false;
for (int64_t resultDim : it.value()) {
@@ -330,18 +326,17 @@ struct ExpandShapeOpInterface
groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
}
Value staticResultDimSz =
- builder.create<arith::ConstantIndexOp>(loc, groupSz);
+ arith::ConstantIndexOp::create(builder, loc, groupSz);
// staticResultDimSz must divide srcDimSz evenly.
Value mod =
- builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
- Value isModZero = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, mod,
- builder.create<arith::ConstantIndexOp>(loc, 0));
- builder.create<cf::AssertOp>(
- loc, isModZero,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "static result dims in reassoc group do not "
- "divide src dim evenly"));
+ arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
+ Value isModZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, mod,
+ arith::ConstantIndexOp::create(builder, loc, 0));
+ cf::AssertOp::create(builder, loc, isModZero,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "static result dims in reassoc group do not "
+ "divide src dim evenly"));
}
}
};
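Every interface in this file follows one recipe, which the syntax migration leaves intact: compute a condition with `arith` ops, then guard it with `cf.assert`. Distilled, with placeholder index values `actual` and `expected`:

  Value ok = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
                                   actual, expected);
  cf::AssertOp::create(builder, loc, ok,
                       RuntimeVerifiableOpInterface::generateErrorMessage(
                           op, "runtime check failed"));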
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index a50b4cf..5af46a4 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
return source;
}
+LogicalResult resolveSourceIndicesExpandShape(
+ Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+ SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+
+  // Traverse all reassociation groups to determine the source indices that
+  // correspond to each group once the op has been folded away.
+ for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+ if (groupSize == 1) {
+ sourceIndices.push_back(indices[group[0]]);
+ continue;
+ }
+ SmallVector<OpFoldResult> groupBasis =
+ llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+ SmallVector<Value> groupIndices =
+ llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+ Value collapsedIndex = affine::AffineLinearizeIndexOp::create(
+ rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+ sourceIndices.push_back(collapsedIndex);
+ }
+ return success();
+}
+
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+ memref::CollapseShapeOp collapseShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+  // Note: collapse_shape requires a strided memref, so we can do this.
+ auto metadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, collapseShapeOp.getSrc());
+ SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+ for (auto [index, group] :
+ llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
+ }
+
+ SmallVector<OpFoldResult> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = affine::AffineDelinearizeIndexOp::create(
+ rewriter, loc, index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
+ }
+ if (collapseShapeOp.getReassociationIndices().empty()) {
+ auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+ int64_t srcRank =
+ cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
+ for (int64_t i = 0; i < srcRank; i++) {
+ sourceIndices.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ }
+ }
+ return success();
+}
+
} // namespace memref
} // namespace mlir
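A sketch of how the folding patterns back in FoldMemRefAliasOps.cpp would call the relocated helper; `loadOp` and `expandShapeOp` stand for the matched ops:

  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, loadOp.getIndices(),
          sourceIndices, /*startsInbounds=*/false)))
    return failure();
  // The access now addresses the expand_shape source directly.
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, expandShapeOp.getSrc(),
                                              sourceIndices);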
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index f5f0bfa..bc3e8b2 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -38,9 +38,6 @@ using namespace mlir::NVVM;
using namespace mlir::transform;
#define DEBUG_TYPE "nvgpu-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e73bdd3..9d5dfc1 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() {
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
}
+acc::LoopParMode
+acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
+ if (hasSeq(deviceType))
+ return LoopParMode::loop_seq;
+ if (hasAuto(deviceType))
+ return LoopParMode::loop_auto;
+ if (hasIndependent(deviceType))
+ return LoopParMode::loop_independent;
+ if (hasSeq())
+ return LoopParMode::loop_seq;
+ if (hasAuto())
+ return LoopParMode::loop_auto;
+ assert(hasIndependent() &&
+ "loop must have default auto, seq, or independent");
+ return LoopParMode::loop_independent;
+}
+
void acc::LoopOp::addGangOperands(
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
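A hedged usage sketch for the new query; `loopOp` is a placeholder and the device type an arbitrary example:

  // A device-specific seq/auto/independent clause wins; otherwise the
  // clause attached to the default device type decides.
  acc::LoopParMode mode =
      loopOp.getDefaultOrDeviceTypeParallelism(acc::DeviceType::Nvidia);
  if (mode == acc::LoopParMode::loop_seq) {
    // Lower as a sequential loop.
  }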
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5d6c5499..c1c1767 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1730,8 +1730,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
if (!mapOp.getDefiningOp())
return emitError(op->getLoc(), "missing map operation");
- if (auto mapInfoOp =
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
+ if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
uint64_t mapTypeBits = mapInfoOp.getMapType();
bool to = mapTypeToBitFlag(
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index b44dbfd..c5ec0ca 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -53,7 +53,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
Value ptrLike;
FromPtrOp fromPtr = *this;
while (fromPtr != nullptr) {
- auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
+ auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
// different.
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
@@ -64,13 +64,12 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
ptrLike = toPtr.getPtr();
} else if (md) {
// Fold if the metadata can be verified to be equal.
- if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+ if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
mdOp && mdOp.getPtr() == toPtr.getPtr())
ptrLike = toPtr.getPtr();
}
// Check for a sequence of casts.
- fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
- : nullptr);
+ fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
}
return ptrLike;
}
@@ -112,13 +111,13 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
Value ptr;
ToPtrOp toPtr = *this;
while (toPtr != nullptr) {
- auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
+ auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
// Cannot fold if it's not a `from_ptr` op.
if (!fromPtr)
return ptr;
ptr = fromPtr.getPtr();
// Check for chains of casts.
- toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
+ toPtr = ptr.getDefiningOp<ToPtrOp>();
}
return ptr;
}
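The rewrites in this file are pure equivalences: `Value::getDefiningOp<OpTy>()` folds the null check and the cast into one call. Side by side, with `value` as a placeholder:

  // Before: null-tolerant cast of the defining op.
  auto longForm = dyn_cast_or_null<ToPtrOp>(value.getDefiningOp());
  // After: the same check, spelled through the Value API.
  auto shortForm = value.getDefiningOp<ToPtrOp>();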
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index 58cd160..9e37bc5 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -148,16 +148,14 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
auto shapeLeft =
- builder
- .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
- inputShape, axisValue)
+ shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
+ inputShape, axisValue)
.getResult(0);
auto sizeLeft =
shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
auto shapeRight =
- builder
- .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
- inputShape, axisNextValue)
+ shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
+ inputShape, axisNextValue)
.getResult(1);
auto sizeRight =
shape::NumElementsOp::create(builder, loc, indexType, shapeRight);
@@ -557,25 +555,24 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
SmallVector<AffineMap> indexingMaps{
builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
- auto result = builder
- .create<linalg::GenericOp>(
- loc,
- init.getType(), // resultType
- ValueRange{input, scales, zeroPoints}, // inputs
- ValueRange{init}, // outputs
- indexingMaps, iteratorTypes,
- [&](OpBuilder &builder, Location loc, ValueRange args) {
- assert(args.size() == 4);
- auto input = args[0];
- auto scale = args[1];
- auto zeroPoint = args[2];
-
- auto result =
- convertRanked(builder, loc, op, input, {}, scale,
- zeroPoint, quantizedType);
-
- linalg::YieldOp::create(builder, loc, result);
- })
+ auto result = linalg::GenericOp::create(
+ builder, loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto input = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto result =
+ convertRanked(builder, loc, op, input, {}, scale,
+ zeroPoint, quantizedType);
+
+ linalg::YieldOp::create(builder, loc, result);
+ })
.getResult(0);
return result;
@@ -660,25 +657,24 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
SmallVector<AffineMap> indexingMaps{
builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
builder.getMultiDimIdentityMap(inputRank)};
- auto result = builder
- .create<linalg::GenericOp>(
- loc,
- init.getType(), // resultType
- ValueRange{input, scales, zeroPoints}, // inputs
- ValueRange{init}, // outputs
- indexingMaps, iteratorTypes,
- [&](OpBuilder &builder, Location loc, ValueRange args) {
- assert(args.size() == 4);
- auto input = args[0];
- auto scale = args[1];
- auto zeroPoint = args[2];
-
- auto result =
- convertRanked(builder, loc, op, input, {}, scale,
- zeroPoint, quantizedType);
-
- linalg::YieldOp::create(builder, loc, result);
- })
+ auto result = linalg::GenericOp::create(
+ builder, loc,
+ init.getType(), // resultType
+ ValueRange{input, scales, zeroPoints}, // inputs
+ ValueRange{init}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ assert(args.size() == 4);
+ auto input = args[0];
+ auto scale = args[1];
+ auto zeroPoint = args[2];
+
+ auto result =
+ convertRanked(builder, loc, op, input, {}, scale,
+ zeroPoint, quantizedType);
+
+ linalg::YieldOp::create(builder, loc, result);
+ })
.getResult(0);
return result;
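The two reflowed builders above share one call shape; condensed, with operand names mirroring the hunks as placeholders:

  auto generic = linalg::GenericOp::create(
      builder, loc, init.getType(), /*inputs=*/ValueRange{input},
      /*outputs=*/ValueRange{init}, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // args: one block argument per input, then one per output.
        linalg::YieldOp::create(b, nestedLoc, args.front());
      });
  Value result = generic.getResult(0);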
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e282ca4..0262a1b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -84,7 +84,7 @@ void SCFDialect::initialize() {
/// Default callback for IfOp builders. Inserts a yield without arguments.
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc);
+ scf::YieldOp::create(builder, loc);
}
/// Verifies that the first block of the given `region` is terminated by a
@@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
+ if (succeeded(parser.parseOptionalKeyword("no_inline")))
+ result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
+
// Introduce the body region and parse it.
Region *body = result.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
@@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printOptionalArrowTypeList(getResultTypes());
-
p << ' ';
+ if (getNoInline())
+ p << "no_inline ";
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
@@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override {
- if (!op.getRegion().hasOneBlock())
+ if (!op.getRegion().hasOneBlock() || op.getNoInline())
return failure();
replaceOpWithRegion(rewriter, op, op.getRegion());
return success();
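The `no_inline` unit attribute threaded through the parser, printer, and this canonicalization opts a region out of inlining. In textual IR the round-tripped form would look roughly as follows (syntax inferred from the print method above):

  //   %r = scf.execute_region -> i32 no_inline {
  //     ...
  //     scf.yield %v : i32
  //   }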
@@ -240,13 +244,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToEnd(prevBlock);
- rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
+ cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
for (Block &blk : op.getRegion()) {
if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
rewriter.setInsertionPoint(yieldOp);
- rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
- yieldOp.getResults());
+ cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
+ yieldOp.getResults());
rewriter.eraseOp(yieldOp);
}
}
@@ -556,8 +560,8 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
rewriter.setInsertionPoint(getOperation());
auto inits = llvm::to_vector(getInitArgs());
inits.append(newInitOperands.begin(), newInitOperands.end());
- scf::ForOp newLoop = rewriter.create<scf::ForOp>(
- getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
+ scf::ForOp newLoop = scf::ForOp::create(
+ rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
[](OpBuilder &, Location, Value, ValueRange) {});
newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
@@ -672,8 +676,8 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
if (llvm::isa<TensorType>(src.getType())) {
- results.push_back(rewriter.create<tensor::InsertSliceOp>(
- forallOp.getLoc(), dst.getType(), src, dst,
+ results.push_back(tensor::InsertSliceOp::create(
+ rewriter, forallOp.getLoc(), dst.getType(), src, dst,
parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
parallelInsertSliceOp.getStrides(),
parallelInsertSliceOp.getStaticOffsets(),
@@ -721,8 +725,8 @@ LoopNest mlir::scf::buildLoopNest(
ValueRange currentIterArgs = iterArgs;
Location currentLoc = loc;
for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
- auto loop = builder.create<scf::ForOp>(
- currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
+ auto loop = scf::ForOp::create(
+ builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
[&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
ValueRange args) {
ivs.push_back(iv);
@@ -741,7 +745,7 @@ LoopNest mlir::scf::buildLoopNest(
// For all loops but the innermost, yield the results of the nested loop.
for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
builder.setInsertionPointToEnd(loops[i].getBody());
- builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
+ scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
}
// In the body of the innermost loop, call the body building function if any
@@ -755,7 +759,7 @@ LoopNest mlir::scf::buildLoopNest(
"loop nest body must return as many values as loop has iteration "
"arguments");
builder.setInsertionPointToEnd(loops.back().getBody());
- builder.create<scf::YieldOp>(loc, results);
+ scf::YieldOp::create(builder, loc, results);
// Return the loops.
ValueVector nestResults;
@@ -800,8 +804,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
}
// 2. Create the new forOp shell.
- scf::ForOp newForOp = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ scf::ForOp newForOp = scf::ForOp::create(
+ rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newIterOperands);
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
@@ -830,7 +834,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
clonedYieldOp.getOperand(yieldIdx));
SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
newYieldOperands[yieldIdx] = castOut;
- rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+ scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
rewriter.eraseOp(clonedYieldOp);
// 6. Inject an outgoing cast op after the forOp.
@@ -925,9 +929,9 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
if (!canonicalize)
return failure();
- scf::ForOp newForOp = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterArgs);
+ scf::ForOp newForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ forOp.getUpperBound(), forOp.getStep(), newIterArgs);
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
@@ -969,8 +973,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
if (keepMask[idx])
filteredOperands.push_back(mergedTerminator.getOperand(idx));
- rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
- filteredOperands);
+ scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
+ filteredOperands);
};
rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
@@ -1110,7 +1114,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
op, replaceAndCastForOpIterArg(
rewriter, op, iterOpOperand, incomingCast.getSource(),
[](OpBuilder &b, Location loc, Type type, Value source) {
- return b.create<tensor::CastOp>(loc, type, source);
+ return tensor::CastOp::create(b, loc, type, source);
}));
return success();
}
@@ -1684,8 +1688,8 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
// Step 3. Create a new scf.forall op with the new shared_outs' operands
// fetched earlier
- auto newForallOp = rewriter.create<scf::ForallOp>(
- forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ auto newForallOp = scf::ForallOp::create(
+ rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
forallOp.getMapping(),
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
@@ -1781,9 +1785,9 @@ struct ForallOpSingleOrZeroIterationDimsFolder
// Replace the loop by a lower-dimensional loop.
ForallOp newOp;
- newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
- newMixedUpperBounds, newMixedSteps,
- op.getOutputs(), std::nullopt, nullptr);
+ newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
+ newMixedUpperBounds, newMixedSteps,
+ op.getOutputs(), std::nullopt, nullptr);
newOp.getBodyRegion().getBlocks().clear();
// The new loop needs to keep all attributes from the old one, except for
// "operandSegmentSizes" and static loop bound attributes which capture
@@ -1866,16 +1870,17 @@ struct FoldTensorCastOfOutputIntoForallOp
// Create new loop.
Location loc = forallOp.getLoc();
- auto newForallOp = rewriter.create<ForallOp>(
- loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
+ auto newForallOp = ForallOp::create(
+ rewriter, loc, forallOp.getMixedLowerBound(),
+ forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
+ newOutputTensors, forallOp.getMapping(),
[&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
auto castBlockArgs =
llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
for (auto [index, cast] : tensorCastProducers) {
Value &oldTypeBBArg = castBlockArgs[index];
- oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
- nestedLoc, cast.dstType, oldTypeBBArg);
+ oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
+ cast.dstType, oldTypeBBArg);
}
// Move old body into new parallel loop.
@@ -1901,8 +1906,8 @@ struct FoldTensorCastOfOutputIntoForallOp
SmallVector<Value> castResults = newForallOp.getResults();
for (auto &item : tensorCastProducers) {
Value &oldTypeResult = castResults[item.first];
- oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
- oldTypeResult);
+ oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
+ oldTypeResult);
}
rewriter.replaceOp(forallOp, castResults);
return success();
@@ -2310,7 +2315,7 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
// Create a replacement operation with empty then and else regions.
auto newOp =
- rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
+ IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
rewriter.createBlock(&newOp.getThenRegion());
rewriter.createBlock(&newOp.getElseRegion());
@@ -2373,8 +2378,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
if (nonHoistable.size() == op->getNumResults())
return failure();
- IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
- /*withElseRegion=*/false);
+ IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
+ /*withElseRegion=*/false);
if (replacement.thenBlock())
rewriter.eraseBlock(replacement.thenBlock());
replacement.getThenRegion().takeBody(op.getThenRegion());
@@ -2399,8 +2404,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
} else if (trueVal == falseVal)
results[it.index()] = trueVal;
else
- results[it.index()] = rewriter.create<arith::SelectOp>(
- op.getLoc(), cond, trueVal, falseVal);
+ results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
+ cond, trueVal, falseVal);
}
rewriter.setInsertionPointToEnd(replacement.thenBlock());
@@ -2489,8 +2494,8 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
if (!trueVal && falseVal) {
if (!opResult.use_empty()) {
Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
- Value notCond = rewriter.create<arith::XOrIOp>(
- op.getLoc(), op.getCondition(),
+ Value notCond = arith::XOrIOp::create(
+ rewriter, op.getLoc(), op.getCondition(),
constDialect
->materializeConstant(rewriter,
rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
@@ -2603,8 +2608,8 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
SmallVector<Type> mergedTypes(prevIf.getResultTypes());
llvm::append_range(mergedTypes, nextIf.getResultTypes());
- IfOp combinedIf = rewriter.create<IfOp>(
- nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
+ IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
+ prevIf.getCondition(), /*hasElse=*/false);
rewriter.eraseBlock(&combinedIf.getThenRegion().back());
rewriter.inlineRegionBefore(prevIf.getThenRegion(),
@@ -2619,7 +2624,7 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
SmallVector<Value> mergedYields(thenYield.getOperands());
llvm::append_range(mergedYields, thenYield2.getOperands());
- rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
+ YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
rewriter.eraseOp(thenYield);
rewriter.eraseOp(thenYield2);
}
@@ -2643,7 +2648,7 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
SmallVector<Value> mergedElseYields(elseYield.getOperands());
llvm::append_range(mergedElseYields, elseYield2.getOperands());
- rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
+ YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
rewriter.eraseOp(elseYield);
rewriter.eraseOp(elseYield2);
}
@@ -2765,9 +2770,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
}
Location loc = op.getLoc();
- Value newCondition = rewriter.create<arith::AndIOp>(
- loc, op.getCondition(), nestedIf.getCondition());
- auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+ Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
+ nestedIf.getCondition());
+ auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
SmallVector<Value> results;
@@ -2775,8 +2780,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
rewriter.setInsertionPoint(newIf);
for (auto idx : elseYieldsToUpgradeToSelect)
- results[idx] = rewriter.create<arith::SelectOp>(
- op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
+ results[idx] =
+ arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
+ thenYield[idx], elseYield[idx]);
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
rewriter.setInsertionPointToEnd(newIf.thenBlock());
@@ -2784,7 +2790,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
if (!elseYield.empty()) {
rewriter.createBlock(&newIf.getElseRegion());
rewriter.setInsertionPointToEnd(newIf.elseBlock());
- rewriter.create<YieldOp>(loc, elseYield);
+ YieldOp::create(rewriter, loc, elseYield);
}
rewriter.replaceOp(op, results);
return success();
@@ -3101,8 +3107,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
}
// Replace the parallel loop by a lower-dimensional parallel loop.
auto newOp =
- rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
- newSteps, op.getInitVals(), nullptr);
+ ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
+ newUpperBounds, newSteps, op.getInitVals(), nullptr);
// Erase the empty block that was inserted by the builder.
rewriter.eraseBlock(newOp.getBody());
// Clone the loop body and remap the block arguments of the collapsed loops
@@ -3482,8 +3488,8 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), term.getCondition().getType(),
+ constantTrue = arith::ConstantOp::create(
+ rewriter, op.getLoc(), term.getCondition().getType(),
rewriter.getBoolAttr(true));
rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
@@ -3625,8 +3631,8 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
}
- auto newWhile =
- rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
+ auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
+ newInitArgs);
Block &newBeforeBlock = *rewriter.createBlock(
&newWhile.getBefore(), /*insertPt*/ {},
@@ -3748,8 +3754,8 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
newCondOpArgs);
}
- auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
- op.getOperands());
+ auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
+ op.getOperands());
Block &newAfterBlock =
*rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
@@ -3855,7 +3861,7 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
}
auto newWhile =
- rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
+ WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
Block &newAfterBlock = *rewriter.createBlock(
&newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
@@ -3984,8 +3990,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
Location loc = op.getLoc();
auto newWhileOp =
- rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
- /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
+ WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
+ /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
Block &newBeforeBlock = *newWhileOp.getBeforeBody();
Block &newAfterBlock = *newWhileOp.getAfterBody();
@@ -4032,9 +4038,10 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
ValueRange argsRange(newArgs);
Location loc = op.getLoc();
- auto newWhileOp = rewriter.create<scf::WhileOp>(
- loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
- /*afterBody*/ nullptr);
+ auto newWhileOp =
+ scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
+ /*beforeBody*/ nullptr,
+ /*afterBody*/ nullptr);
Block &newBeforeBlock = *newWhileOp.getBeforeBody();
Block &newAfterBlock = *newWhileOp.getAfterBody();
@@ -4128,8 +4135,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
for (auto &&[i, j] : llvm::enumerate(*mapping))
newResultTypes[j] = loop.getResult(i).getType();
- auto newLoop = rewriter.create<WhileOp>(
- loop.getLoc(), newResultTypes, loop.getInits(),
+ auto newLoop = WhileOp::create(
+ rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
/*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
auto newBefore = newLoop.getBeforeBody();
auto newAfter = newLoop.getAfterBody();
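
Every hunk in this file applies the same mechanical rewrite: the builder-template spelling `rewriter.create<OpTy>(loc, ...)` becomes the op-centric `OpTy::create(rewriter, loc, ...)`, with the builder passed as the first argument. A minimal compiling sketch of the two equivalent spellings (the helper name `makeOne` is illustrative, not from the patch):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Builds the same arith.constant either way; only the call syntax differs.
// The static OpTy::create form is the spelling this patch migrates to.
static Value makeOne(OpBuilder &b, Location loc) {
  // Old: return b.create<arith::ConstantIndexOp>(loc, 1);
  return arith::ConstantIndexOp::create(b, loc, 1);
}
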
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 9a68565..aea842d 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -160,7 +160,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
scf::ExecuteRegionOp executeRegionOp =
- b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
+ scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes());
{
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
@@ -169,7 +169,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
assert(clonedRegion.empty() && "expected empty region");
b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
clonedRegion.end());
- b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
+ scf::YieldOp::create(b, op->getLoc(), clonedOp->getResults());
}
b.replaceOp(op, executeRegionOp.getResults());
return executeRegionOp;
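
Condensed, the pattern `wrapInExecuteRegion` implements is: create the region op, populate its single block, yield the results, and replace the original op. A simplified sketch under the new spelling (it clones the op rather than inlining its region, and omits the real function's assertions):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static scf::ExecuteRegionOp wrapSketch(RewriterBase &b, Operation *op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  auto exec =
      scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes());
  b.setInsertionPointToStart(&exec.getRegion().emplaceBlock());
  Operation *cloned = b.clone(*op); // the real code inlines the region instead
  scf::YieldOp::create(b, op->getLoc(), cloned->getResults());
  b.replaceOp(op, exec.getResults());
  return exec;
}
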
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 8509382..f8799c5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -41,7 +41,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
// iter_arg's layout map must be changed (see uses of `castBuffer`).
assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
"scf.while op bufferization: cast incompatible");
- return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
+ return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
}
/// Helper function for loop bufferization. Return "true" if the given value
@@ -189,7 +189,7 @@ struct ExecuteRegionOpInterface
// Create new op and move over region.
auto newOp =
- rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
+ scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes);
newOp.getRegion().takeBody(executeRegionOp.getRegion());
// Bufferize every block.
@@ -203,8 +203,8 @@ struct ExecuteRegionOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
if (isa<TensorType>(it.value())) {
- newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
- executeRegionOp.getLoc(), it.value(),
+ newResults.push_back(bufferization::ToTensorOp::create(
+ rewriter, executeRegionOp.getLoc(), it.value(),
newOp->getResult(it.index())));
} else {
newResults.push_back(newOp->getResult(it.index()));
@@ -258,9 +258,9 @@ struct IfOpInterface
// Create new op.
rewriter.setInsertionPoint(ifOp);
- auto newIfOp =
- rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
- /*withElseRegion=*/true);
+ auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
+ ifOp.getCondition(),
+ /*withElseRegion=*/true);
// Move over then/else blocks.
rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
@@ -372,9 +372,9 @@ struct IndexSwitchOpInterface
// Create new op.
rewriter.setInsertionPoint(switchOp);
- auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
- switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
- switchOp.getCases().size());
+ auto newSwitchOp = scf::IndexSwitchOp::create(
+ rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
+ switchOp.getCases(), switchOp.getCases().size());
// Move over blocks.
for (auto [src, dest] :
@@ -497,10 +497,10 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
size_t idx = it.index();
Value val = it.value();
if (tensorIndices.contains(idx)) {
- result.push_back(rewriter
- .create<bufferization::ToTensorOp>(
- val.getLoc(), oldBbArgs[idx].getType(), val)
- .getResult());
+ result.push_back(
+ bufferization::ToTensorOp::create(rewriter, val.getLoc(),
+ oldBbArgs[idx].getType(), val)
+ .getResult());
} else {
result.push_back(val);
}
@@ -767,8 +767,8 @@ struct ForOpInterface
}
// Construct a new scf.for op with memref instead of tensor values.
- auto newForOp = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ auto newForOp = scf::ForOp::create(
+ rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), castedInitArgs);
newForOp->setAttrs(forOp->getAttrs());
Block *loopBody = newForOp.getBody();
@@ -1003,8 +1003,8 @@ struct WhileOpInterface
// Construct a new scf.while op with memref instead of tensor values.
ValueRange argsRangeBefore(castedInitArgs);
TypeRange argsTypesBefore(argsRangeBefore);
- auto newWhileOp = rewriter.create<scf::WhileOp>(
- whileOp.getLoc(), argsTypesAfter, castedInitArgs);
+ auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
+ argsTypesAfter, castedInitArgs);
// Add before/after regions to the new op.
SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
@@ -1263,8 +1263,8 @@ struct ForallOpInterface
forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
BlockArgument bbArg = std::get<0>(it);
Value buffer = std::get<1>(it);
- Value bufferAsTensor = rewriter.create<ToTensorOp>(
- forallOp.getLoc(), bbArg.getType(), buffer);
+ Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
+ bbArg.getType(), buffer);
bbArg.replaceAllUsesWith(bufferAsTensor);
}
@@ -1272,8 +1272,8 @@ struct ForallOpInterface
// introduced terminator.
rewriter.setInsertionPoint(forallOp);
ForallOp newForallOp;
- newForallOp = rewriter.create<ForallOp>(
- forallOp.getLoc(), forallOp.getMixedLowerBound(),
+ newForallOp = ForallOp::create(
+ rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
/*outputs=*/ValueRange(), forallOp.getMapping());
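
The recurring step in this file is rewrapping a bufferized memref result in `bufferization.to_tensor` when the original result was tensor-typed, so remaining tensor users stay valid. A reduced sketch of that step (helper name illustrative):

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Rewraps a bufferized result: tensor-typed results get a to_tensor op,
// anything else is forwarded unchanged.
static Value rewrapIfTensor(RewriterBase &rewriter, Location loc,
                            Type origType, Value newResult) {
  if (isa<TensorType>(origType))
    return bufferization::ToTensorOp::create(rewriter, loc, origType,
                                             newResult);
  return newResult;
}
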
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index 3e93dc8..bee7780 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -50,19 +50,19 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
SmallVector<Value> initArgs;
initArgs.push_back(forOp.getLowerBound());
llvm::append_range(initArgs, forOp.getInitArgs());
- auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
- forOp->getAttrs());
+ auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
+ forOp->getAttrs());
// 'before' region contains the loop condition and forwarding of iteration
// arguments to the 'after' region.
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
- auto cmpOp = rewriter.create<arith::CmpIOp>(
- whileOp.getLoc(), arith::CmpIPredicate::slt,
+ auto cmpOp = arith::CmpIOp::create(
+ rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
beforeBlock->getArgument(0), forOp.getUpperBound());
- rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
- beforeBlock->getArguments());
+ scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
+ beforeBlock->getArguments());
// Inline for-loop body into an executeRegion operation in the "after"
// region. The return type of the execRegionOp does not contain the
@@ -72,8 +72,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// Add induction variable incrementation
rewriter.setInsertionPointToEnd(afterBlock);
- auto ivIncOp = rewriter.create<arith::AddIOp>(
- whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
+ auto ivIncOp =
+ arith::AddIOp::create(rewriter, whileOp.getLoc(),
+ afterBlock->getArgument(0), forOp.getStep());
// Rewrite uses of the for-loop block arguments to the new while-loop
// "after" arguments
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
index 44e6840..b95604f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -40,7 +40,7 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
SmallVector<Value> steps = forallOp.getStep(rewriter);
// Create empty scf.parallel op.
- auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
+ auto parallelOp = scf::ParallelOp::create(rewriter, loc, lbs, ubs, steps);
rewriter.eraseBlock(&parallelOp.getRegion().front());
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
parallelOp.getRegion().begin());
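
The create-erase-inline sequence above is a common region-transplant idiom: build the destination op (whose builder inserts an empty entry block), drop that block, then inline the source region in its place. A condensed sketch under the same assumptions (helper name illustrative):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void transplantBody(RewriterBase &rewriter, scf::ForallOp forallOp,
                           ValueRange lbs, ValueRange ubs, ValueRange steps) {
  auto parallelOp =
      scf::ParallelOp::create(rewriter, forallOp.getLoc(), lbs, ubs, steps);
  // The builder added an empty entry block; drop it and move the real body in.
  rewriter.eraseBlock(&parallelOp.getRegion().front());
  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
                              parallelOp.getRegion().begin());
}
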
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index bcecef5..1130538 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -19,12 +19,10 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
#define DEBUG_TYPE "scf-loop-pipelining"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
using namespace mlir::scf;
@@ -100,7 +98,7 @@ public:
bool LoopPipelinerInternal::initializeLoopInfo(
ForOp op, const PipeliningOption &options) {
- LDBG("Start initializeLoopInfo");
+ LDBG() << "Start initializeLoopInfo";
forOp = op;
ub = forOp.getUpperBound();
lb = forOp.getLowerBound();
@@ -109,7 +107,7 @@ bool LoopPipelinerInternal::initializeLoopInfo(
std::vector<std::pair<Operation *, unsigned>> schedule;
options.getScheduleFn(forOp, schedule);
if (schedule.empty()) {
- LDBG("--empty schedule -> BAIL");
+ LDBG() << "--empty schedule -> BAIL";
return false;
}
@@ -126,7 +124,7 @@ bool LoopPipelinerInternal::initializeLoopInfo(
auto stepCst = getConstantIntValue(step);
if (!upperBoundCst || !lowerBoundCst || !stepCst) {
if (!options.supportDynamicLoops) {
- LDBG("--dynamic loop not supported -> BAIL");
+ LDBG() << "--dynamic loop not supported -> BAIL";
return false;
}
} else {
@@ -134,21 +132,21 @@ bool LoopPipelinerInternal::initializeLoopInfo(
int64_t lbImm = lowerBoundCst.value();
int64_t stepImm = stepCst.value();
if (stepImm <= 0) {
- LDBG("--invalid loop step -> BAIL");
+ LDBG() << "--invalid loop step -> BAIL";
return false;
}
int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
if (numIteration >= maxStage) {
dynamicLoop = false;
} else if (!options.supportDynamicLoops) {
- LDBG("--fewer loop iterations than pipeline stages -> BAIL");
+ LDBG() << "--fewer loop iterations than pipeline stages -> BAIL";
return false;
}
}
peelEpilogue = options.peelEpilogue;
predicateFn = options.predicateFn;
if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
- LDBG("--no epilogue or predicate set -> BAIL");
+ LDBG() << "--no epilogue or predicate set -> BAIL";
return false;
}
@@ -156,13 +154,13 @@ bool LoopPipelinerInternal::initializeLoopInfo(
for (Operation &op : forOp.getBody()->without_terminator()) {
if (!stages.contains(&op)) {
op.emitOpError("not assigned a pipeline stage");
- LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
+ LDBG() << "--op not assigned a pipeline stage: " << op << " -> BAIL";
return false;
}
}
if (!verifySchedule()) {
- LDBG("--invalid schedule: " << op << " -> BAIL");
+ LDBG() << "--invalid schedule: " << op << " -> BAIL";
return false;
}
@@ -173,15 +171,16 @@ bool LoopPipelinerInternal::initializeLoopInfo(
(void)stageNum;
if (op == forOp.getBody()->getTerminator()) {
op->emitError("terminator should not be assigned a stage");
- LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
+ LDBG() << "--terminator should not be assigned stage: " << *op
+ << " -> BAIL";
return false;
}
if (op->getBlock() != forOp.getBody()) {
op->emitOpError("the owning Block of all operations assigned a stage "
"should be the loop body block");
- LDBG("--the owning Block of all operations assigned a stage "
- "should be the loop body block: "
- << *op << " -> BAIL");
+ LDBG() << "--the owning Block of all operations assigned a stage "
+ "should be the loop body block: "
+ << *op << " -> BAIL";
return false;
}
}
@@ -196,8 +195,8 @@ bool LoopPipelinerInternal::initializeLoopInfo(
return !def ||
(!stages.contains(def) && forOp->isAncestor(def));
})) {
- LDBG("--only support loop carried dependency with a distance of 1 or "
- "defined outside of the loop -> BAIL");
+ LDBG() << "--only support loop carried dependency with a distance of 1 or "
+ "defined outside of the loop -> BAIL";
return false;
}
annotateFn = options.annotateFn;
@@ -279,25 +278,25 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
if (dynamicLoop) {
Type t = ub.getType();
// pred = ub > lb + (i * step)
- Value iv = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(t, i))));
- predicates[i] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, iv, ub);
+ Value iv = arith::AddIOp::create(
+ rewriter, loc, lb,
+ arith::MulIOp::create(
+ rewriter, loc, step,
+ arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(t, i))));
+ predicates[i] = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::slt, iv, ub);
}
// Special handling for the induction variable, as the increment is implicit.
// iv = lb + i * step
Type t = lb.getType();
- Value iv = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::ConstantOp>(loc,
- rewriter.getIntegerAttr(t, i))));
+ Value iv = arith::AddIOp::create(
+ rewriter, loc, lb,
+ arith::MulIOp::create(
+ rewriter, loc, step,
+ arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(t, i))));
setValueMapping(forOp.getInductionVar(), iv, i);
for (Operation *op : opOrder) {
if (stages[op] > i)
@@ -332,8 +331,8 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
Value prevValue = valueMapping
[forOp.getRegionIterArgs()[operand.getOperandNumber()]]
[i - stages[op]];
- source = rewriter.create<arith::SelectOp>(
- loc, predicates[predicateIdx], source, prevValue);
+ source = arith::SelectOp::create(
+ rewriter, loc, predicates[predicateIdx], source, prevValue);
}
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
source, i - stages[op] + 1);
@@ -444,15 +443,15 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
Type t = ub.getType();
Location loc = forOp.getLoc();
// newUb = ub - maxStage * step
- Value maxStageValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(t, maxStage));
+ Value maxStageValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(t, maxStage));
Value maxStageByStep =
- rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
- newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
+ arith::MulIOp::create(rewriter, loc, step, maxStageValue);
+ newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep);
}
auto newForOp =
- rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
- forOp.getStep(), newLoopArg);
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb,
+ forOp.getStep(), newLoopArg);
// When there are no iter args, the loop body terminator will be created.
// Since we always create it below, remove the terminator if it was created.
if (!newForOp.getBody()->empty())
@@ -483,16 +482,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
Type t = ub.getType();
for (unsigned i = 0; i < maxStage; i++) {
// c = ub - (maxStage - i) * step
- Value c = rewriter.create<arith::SubIOp>(
- loc, ub,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
-
- Value pred = rewriter.create<arith::CmpIOp>(
- newForOp.getLoc(), arith::CmpIPredicate::slt,
- newForOp.getInductionVar(), c);
+ Value c = arith::SubIOp::create(
+ rewriter, loc, ub,
+ arith::MulIOp::create(
+ rewriter, loc, step,
+ arith::ConstantOp::create(
+ rewriter, loc,
+ rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
+
+ Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(),
+ arith::CmpIPredicate::slt,
+ newForOp.getInductionVar(), c);
predicates[i] = pred;
}
}
@@ -515,13 +515,13 @@ LogicalResult LoopPipelinerInternal::createKernel(
// offset = (maxStage - stages[op]) * step
Type t = step.getType();
- Value offset = rewriter.create<arith::MulIOp>(
- forOp.getLoc(), step,
- rewriter.create<arith::ConstantOp>(
- forOp.getLoc(),
+ Value offset = arith::MulIOp::create(
+ rewriter, forOp.getLoc(), step,
+ arith::ConstantOp::create(
+ rewriter, forOp.getLoc(),
rewriter.getIntegerAttr(t, maxStage - stages[op])));
- Value iv = rewriter.create<arith::AddIOp>(
- forOp.getLoc(), newForOp.getInductionVar(), offset);
+ Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(),
+ newForOp.getInductionVar(), offset);
nestedNewOp->setOperand(operand->getOperandNumber(), iv);
rewriter.setInsertionPointAfter(newOp);
continue;
@@ -594,8 +594,8 @@ LogicalResult LoopPipelinerInternal::createKernel(
auto defStage = stages.find(def);
if (defStage != stages.end() && defStage->second < maxStage) {
Value pred = predicates[defStage->second];
- source = rewriter.create<arith::SelectOp>(
- pred.getLoc(), pred, source,
+ source = arith::SelectOp::create(
+ rewriter, pred.getLoc(), pred, source,
newForOp.getBody()
->getArguments()[yieldOperand.getOperandNumber() + 1]);
}
@@ -638,7 +638,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
maxStage - defStage->second + 1);
}
}
- rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
+ scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands);
return success();
}
@@ -652,8 +652,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
// removed by dead code if not used.
auto createConst = [&](int v) {
- return rewriter.create<arith::ConstantOp>(loc,
- rewriter.getIntegerAttr(t, v));
+ return arith::ConstantOp::create(rewriter, loc,
+ rewriter.getIntegerAttr(t, v));
};
// total_iterations = cdiv(range_diff, step);
@@ -661,42 +661,44 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
// - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
Value zero = createConst(0);
Value one = createConst(1);
- Value stepLessZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, step, zero);
- Value stepDecr =
- rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
+ Value stepLessZero = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::slt, step, zero);
+ Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one,
+ createConst(-1));
- Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
- Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
+ Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb);
+ Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step);
Value rangeDecr =
- rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
- Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
+ arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr);
+ Value totalIterations =
+ arith::DivSIOp::create(rewriter, loc, rangeDecr, step);
// If total_iters < max_stage, start the epilogue at zero to match the
// ramp-up in the prologue.
// start_iter = max(0, total_iters - max_stage)
- Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
- createConst(maxStage));
- iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+ Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations,
+ createConst(maxStage));
+ iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI);
// Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 1; i <= maxStage; i++) {
// newLastIter = lb + step * iterI
- Value newlastIter = rewriter.create<arith::AddIOp>(
- loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+ Value newlastIter = arith::AddIOp::create(
+ rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI));
setValueMapping(forOp.getInductionVar(), newlastIter, i);
// increment to next iterI
- iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
+ iterI = arith::AddIOp::create(rewriter, loc, iterI, one);
if (dynamicLoop) {
// Disable stages when `i` is greater than total_iters.
// pred = total_iters >= i
- predicates[i] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
+ predicates[i] =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
+ totalIterations, createConst(i));
}
}
@@ -758,8 +760,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
unsigned nextVersion = currentVersion + 1;
Value pred = predicates[currentVersion];
Value prevValue = valueMapping[mapVal][currentVersion];
- auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
- prevValue);
+ auto selOp = arith::SelectOp::create(rewriter, loc, pred,
+ pair.value(), prevValue);
returnValues[ri] = selOp;
if (nextVersion <= maxStage)
setValueMapping(mapVal, selOp, nextVersion);
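
The epilogue arithmetic rewritten above is a sign-robust ceiling division followed by a clamp on the epilogue start. A standalone numeric check of the same two formulas, in plain C++ rather than MLIR (values chosen for illustration):

#include <algorithm>
#include <cassert>

int main() {
  // total_iterations = cdiv(range_diff, step), robust to negative steps:
  auto totalIters = [](int lb, int ub, int step) {
    int stepDecr = step < 0 ? 1 : -1;
    return (ub - lb + step + stepDecr) / step;
  };
  assert(totalIters(0, 10, 3) == 4);   // iterations at 0, 3, 6, 9
  assert(totalIters(10, 0, -3) == 4);  // descending: 10, 7, 4, 1
  // start_iter = max(0, total_iters - max_stage)
  int maxStage = 2;
  assert(std::max(0, totalIters(0, 10, 3) - maxStage) == 2);
  return 0;
}
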
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index d17cd47..4752c08 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -63,13 +63,13 @@ static void specializeParallelLoopForUnrolling(ParallelOp op) {
Value cond;
for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
Value constant =
- b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
- Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
- std::get<0>(bound), constant);
- cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
+ arith::ConstantIndexOp::create(b, op.getLoc(), std::get<1>(bound));
+ Value cmp = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq,
+ std::get<0>(bound), constant);
+ cond = cond ? arith::AndIOp::create(b, op.getLoc(), cond, cmp) : cmp;
map.map(std::get<0>(bound), constant);
}
- auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
+ auto ifOp = scf::IfOp::create(b, op.getLoc(), cond, /*withElseRegion=*/true);
ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
ifOp.getElseBodyBuilder().clone(*op.getOperation());
op.erase();
@@ -94,11 +94,11 @@ static void specializeForLoopForUnrolling(ForOp op) {
OpBuilder b(op);
IRMapping map;
- Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
- Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
- bound, constant);
+ Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant);
+ Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq,
+ bound, constant);
map.map(bound, constant);
- auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
+ auto ifOp = scf::IfOp::create(b, op.getLoc(), cond, /*withElseRegion=*/true);
ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
ifOp.getElseBodyBuilder().clone(*op.getOperation());
op.erase();
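
The specialization emits a runtime guard and two clones of the loop: one with the bound folded to its known constant (so later unrolling sees a static trip count) and the original dynamic loop as fallback. The same shape in plain C++ (names illustrative):

#include <cstdio>

static void work(int i) { std::printf("%d\n", i); }

// Guarded specialization: a clone with the constant bound, plus the
// general dynamic loop in the else branch.
static void loop(int n) {
  if (n == 64) {
    for (int i = 0; i < 64; ++i)
      work(i);
  } else {
    for (int i = 0; i < n; ++i)
      work(i);
  }
}

int main() { loop(64); return 0; }
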
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index ad12673..694cd85 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -190,8 +190,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
IRRewriter b(builder);
b.setInsertionPoint(secondPloop);
- auto newSecondPloop = b.create<ParallelOp>(
- secondPloop.getLoc(), secondPloop.getLowerBound(),
+ auto newSecondPloop = ParallelOp::create(
+ b, secondPloop.getLoc(), secondPloop.getLowerBound(),
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
Block *newBlock = newSecondPloop.getBody();
@@ -212,7 +212,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
- auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+ auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs);
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
term1.getReductions(), term2.getReductions()))) {
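
Fusion keeps both loops' reductions by concatenating their operand lists into a single scf.reduce terminator; the per-operand reduction regions are then moved over one by one. A sketch of just the operand merge, assuming `term1` and `term2` are the two original terminators:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static scf::ReduceOp mergeReduces(IRRewriter &b, scf::ReduceOp term1,
                                  scf::ReduceOp term2) {
  SmallVector<Value> args(term1.getOperands().begin(),
                          term1.getOperands().end());
  llvm::append_range(args, term2.getOperands());
  // One terminator carrying both loops' reduction operands; the reduction
  // regions themselves are transplanted afterwards.
  return scf::ReduceOp::create(b, term2.getLoc(), args);
}
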
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index 66f7bc2..081f5fb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -58,28 +58,28 @@ std::pair<ParallelOp, ParallelOp>
mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
bool noMinMaxBounds) {
OpBuilder b(op);
- auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+ auto zero = arith::ConstantIndexOp::create(b, op.getLoc(), 0);
SmallVector<Value, 2> tileSizeConstants;
tileSizeConstants.reserve(op.getUpperBound().size());
for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
if (i < tileSizes.size())
tileSizeConstants.push_back(
- b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i]));
+ arith::ConstantIndexOp::create(b, op.getLoc(), tileSizes[i]));
else
// Just pick 1 for the remaining dimensions.
tileSizeConstants.push_back(
- b.create<arith::ConstantIndexOp>(op.getLoc(), 1));
+ arith::ConstantIndexOp::create(b, op.getLoc(), 1));
}
// Create the outer loop with adjusted steps.
SmallVector<Value, 2> newSteps;
newSteps.reserve(op.getStep().size());
for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
- newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step),
- std::get<1>(step)));
+ newSteps.push_back(arith::MulIOp::create(b, op.getLoc(), std::get<0>(step),
+ std::get<1>(step)));
}
- auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(),
- op.getUpperBound(), newSteps);
+ auto outerLoop = ParallelOp::create(b, op.getLoc(), op.getLowerBound(),
+ op.getUpperBound(), newSteps);
b.setInsertionPointToStart(outerLoop.getBody());
// Compute min(size, dim - offset) to avoid out-of-bounds accesses.
@@ -100,11 +100,10 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
op.getStep(), tileSizeConstants)) {
// Collect the statically known loop bounds
auto lowerBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
+ lowerBound.getDefiningOp<arith::ConstantIndexOp>();
auto upperBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
- auto stepConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
+ upperBound.getDefiningOp<arith::ConstantIndexOp>();
+ auto stepConstant = step.getDefiningOp<arith::ConstantIndexOp>();
auto tileSize =
cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
// If the loop bounds and the loop step are constant and if the number of
@@ -130,45 +129,45 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
// Otherwise, we dynamically compute the bound for
// each iteration of the outer loop.
newBounds.push_back(
- b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
- ValueRange{newStep, upperBound, iv}));
+ affine::AffineMinOp::create(b, op.getLoc(), b.getIndexType(), minMap,
+ ValueRange{newStep, upperBound, iv}));
}
- auto innerLoop = b.create<ParallelOp>(
- op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
+ auto innerLoop = ParallelOp::create(
+ b, op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
op.getStep());
if (noMinMaxBounds && needInboundCheck) {
b.setInsertionPointToStart(innerLoop.getBody());
// Insert in-bound check
Value inbound =
- b.create<arith::ConstantIntOp>(op.getLoc(), b.getIntegerType(1), 1);
+ arith::ConstantIntOp::create(b, op.getLoc(), b.getIntegerType(1), 1);
for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
innerLoop.getInductionVars(), innerLoop.getStep())) {
// %in_bound = %in_bound &&
// (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
- Value index = b.create<arith::AddIOp>(
- op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep),
- outerIV);
- Value dimInbound = b.create<arith::CmpIOp>(
- op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
- inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound);
+ Value index = arith::AddIOp::create(
+ b, op.getLoc(),
+ arith::MulIOp::create(b, op.getLoc(), innerIV, innerStep), outerIV);
+ Value dimInbound = arith::CmpIOp::create(
+ b, op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
+ inbound = arith::AndIOp::create(b, op.getLoc(), inbound, dimInbound);
}
- auto ifInbound = b.create<IfOp>(op.getLoc(),
- /*resultTypes*/ ArrayRef<Type>{}, inbound,
- /*hasElseRegion*/ false);
+ auto ifInbound = IfOp::create(b, op.getLoc(),
+ /*resultTypes*/ ArrayRef<Type>{}, inbound,
+ /*hasElseRegion*/ false);
ifInbound.getThenRegion().takeBody(op.getRegion());
Block &thenBlock = ifInbound.getThenRegion().front();
// Replace the scf.reduce terminator with an scf.yield terminator.
Operation *reduceOp = thenBlock.getTerminator();
b.setInsertionPointToEnd(&thenBlock);
- b.create<scf::YieldOp>(reduceOp->getLoc());
+ scf::YieldOp::create(b, reduceOp->getLoc());
reduceOp->erase();
b.setInsertionPointToStart(innerLoop.getBody());
for (const auto &ivs : llvm::enumerate(llvm::zip(
innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
- auto newIndex = b.create<arith::AddIOp>(
- op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
+ auto newIndex = arith::AddIOp::create(
+ b, op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
thenBlock.getArgument(ivs.index())
.replaceAllUsesExcept(newIndex, newIndex);
}
@@ -179,8 +178,8 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
outerLoop.getInductionVars())) {
Value innerIndex = std::get<0>(ivs);
- auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs),
- std::get<1>(ivs));
+ auto newIndex = arith::AddIOp::create(b, op.getLoc(), std::get<0>(ivs),
+ std::get<1>(ivs));
innerIndex.replaceAllUsesExcept(newIndex, newIndex);
}
}
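
The tiling arithmetic in this file: the outer stride becomes `step * tileSize`, and the inner bound is clamped with a min so partial tiles stay in bounds. A plain-C++ rendering of the recombined index math (illustrative values):

#include <algorithm>
#include <cstdio>

int main() {
  int lb = 0, ub = 10, step = 1, tileSize = 4;
  int newStep = step * tileSize;               // outer stride: step * tile
  for (int outer = lb; outer < ub; outer += newStep) {
    int bound = std::min(newStep, ub - outer); // the affine.min clamp
    for (int inner = 0; inner < bound; inner += step)
      std::printf("%d\n", outer + inner);      // recombined original index
  }
  return 0; // visits exactly 0..9, including the partial last tile
}
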
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 0932624..1b07b77 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -112,11 +112,11 @@ public:
// We can not do clone as the number of result types after conversion
// might be different.
- ForOp newOp = rewriter.create<ForOp>(
- op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
- llvm::getSingleElement(adaptor.getUpperBound()),
- llvm::getSingleElement(adaptor.getStep()),
- flattenValues(adaptor.getInitArgs()));
+ ForOp newOp = ForOp::create(rewriter, op.getLoc(),
+ llvm::getSingleElement(adaptor.getLowerBound()),
+ llvm::getSingleElement(adaptor.getUpperBound()),
+ llvm::getSingleElement(adaptor.getStep()),
+ flattenValues(adaptor.getInitArgs()));
// Preserve whatever attributes were on the original op.
newOp->setAttrs(op->getAttrs());
@@ -142,9 +142,9 @@ public:
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
- IfOp newOp = rewriter.create<IfOp>(
- op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
- true);
+ IfOp newOp =
+ IfOp::create(rewriter, op.getLoc(), dstTypes,
+ llvm::getSingleElement(adaptor.getCondition()), true);
newOp->setAttrs(op->getAttrs());
// We do not need the empty blocks created by rewriter.
@@ -171,8 +171,8 @@ public:
std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
- auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
- flattenValues(adaptor.getOperands()));
+ auto newOp = WhileOp::create(rewriter, op.getLoc(), dstTypes,
+ flattenValues(adaptor.getOperands()));
for (auto i : {0u, 1u}) {
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
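
These conversion patterns rebuild each op instead of cloning it because 1:N type conversion may change the number of results: a fresh op is created with the converted type list and the original attributes are copied over. A sketch for the scf.while case (helper name and the `dstTypes`/`flatOperands` parameters are illustrative):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static scf::WhileOp rebuildWhile(ConversionPatternRewriter &rewriter,
                                 scf::WhileOp op, TypeRange dstTypes,
                                 ValueRange flatOperands) {
  // A fresh op with the converted result types; cloning would carry over
  // the old result count.
  auto newOp =
      scf::WhileOp::create(rewriter, op.getLoc(), dstTypes, flatOperands);
  newOp->setAttrs(op->getAttrs());
  return newOp;
}
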
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 484b03d..c0e47ee 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -447,9 +447,9 @@ static LogicalResult generateLoopNestUsingForOp(
SmallVector<Value> ivs;
for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
auto loop =
- rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
- [](OpBuilder &bodyBuilder, Location bodyLoc,
- Value iv, ValueRange /*iterArgs*/) {});
+ scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors,
+ [](OpBuilder &bodyBuilder, Location bodyLoc,
+ Value iv, ValueRange /*iterArgs*/) {});
loops.push_back(loop);
ivs.push_back(loop.getInductionVar());
rewriter.setInsertionPointToEnd(loop.getBody());
@@ -476,12 +476,12 @@ static LogicalResult generateLoopNestUsingForOp(
resultSizes)) {
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
rewriter.getIndexAttr(1));
- auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
- loc, tiledValue, destinationTensor, resultOffset, resultSize,
+ auto insertSlice = tensor::InsertSliceOp::create(
+ rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
resultStride);
yieldedValues.push_back(insertSlice);
}
- rewriter.create<scf::YieldOp>(loc, yieldedValues);
+ scf::YieldOp::create(rewriter, loc, yieldedValues);
// Add the scf.yield operations for all the outer loops.
for (auto [outerLoop, innerLoop] :
@@ -489,7 +489,7 @@ static LogicalResult generateLoopNestUsingForOp(
MutableArrayRef(loops).drop_front())) {
rewriter.setInsertionPointToEnd(
cast<scf::ForOp>(outerLoop.getOperation()).getBody());
- rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
+ scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
}
return success();
}
@@ -530,14 +530,14 @@ static LogicalResult generateLoopNestUsingForallOp(
continue;
nonZeroNumThreads.push_back(nt);
}
- forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
- destinationTensors, mappingAttr);
+ forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
+ destinationTensors, mappingAttr);
} else {
SmallVector<OpFoldResult> lbs, ubs, steps;
std::tie(lbs, ubs, steps) =
getLoopBounds(rewriter, loc, loopRanges, tileSizes);
- forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
- destinationTensors, mappingAttr);
+ forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
+ destinationTensors, mappingAttr);
}
loops.push_back(forallOp);
@@ -558,9 +558,9 @@ static LogicalResult generateLoopNestUsingForallOp(
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(
- loc, tiledValue, destinationTensor, resultOffset, resultSize,
- resultStride);
+ tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
+ destinationTensor, resultOffset,
+ resultSize, resultStride);
}
return success();
}
@@ -795,9 +795,9 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
auto inits = llvm::to_vector(loopOp.getInitArgs());
inits.append(newInitOperands.begin(), newInitOperands.end());
- auto newLoop = rewriter.create<scf::ForOp>(
- loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
- inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ auto newLoop = scf::ForOp::create(
+ rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
+ loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
// Move the loop body to the new op.
Block *loopBody = loopOp.getBody();
@@ -826,9 +826,9 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
resultSizes)) {
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
rewriter.getIndexAttr(1));
- Value insert = rewriter.create<tensor::InsertSliceOp>(
- yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
- resultStride);
+ Value insert = tensor::InsertSliceOp::create(
+ rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
+ resultSize, resultStride);
newYieldValues.push_back(insert);
}
@@ -848,8 +848,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
rewriter.setInsertionPoint(loopOp);
auto inits = llvm::to_vector(loopOp.getOutputs());
inits.append(newInitOperands.begin(), newInitOperands.end());
- auto newLoop = rewriter.create<scf::ForallOp>(
- loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
+ auto newLoop = scf::ForallOp::create(
+ rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
loopOp.getMixedStep(), inits, loopOp.getMapping(),
[](OpBuilder &, Location, ValueRange) {});
@@ -881,9 +881,9 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
rewriter.getIndexAttr(1));
- rewriter.create<tensor::ParallelInsertSliceOp>(
- terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
- resultStride);
+ tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
+ tiledValue, iterArg, resultOffset,
+ resultSize, resultStride);
}
rewriter.replaceOp(loopOp,
@@ -932,9 +932,9 @@ static LogicalResult addInitOperandsToLoopNest(
// Create a new loop with the new init values for this loop.
SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
newInits.append(newInitValues.begin(), newInitValues.end());
- auto newLoop = rewriter.create<scf::ForOp>(
- forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
- forLoop.getStep(), newInits,
+ auto newLoop = scf::ForOp::create(
+ rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
+ forLoop.getUpperBound(), forLoop.getStep(), newInits,
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
// Merge the body of the new loop with the body of the old loops.
@@ -1416,8 +1416,8 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
rewriter.setInsertionPoint(tiledDestStyleOp);
for (const auto &&[index, newRegionArg] :
llvm::enumerate(newRegionIterArgs)) {
- auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- loc, newRegionArg, offsetList[index], sizesList[index],
+ auto destSlice = tensor::ExtractSliceOp::create(
+ rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
SmallVector<OpFoldResult>(offsetList[index].size(),
rewriter.getIndexAttr(1)));
generatedSlices.push_back(destSlice);
@@ -2089,8 +2089,8 @@ cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
template <>
tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
- return rewriter.create<tensor::InsertSliceOp>(
- insertSliceOp->getLoc(), insertSliceOp.getSource(),
+ return tensor::InsertSliceOp::create(
+ rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
}
@@ -2311,8 +2311,9 @@ mlir::scf::tileAndFuseConsumerOfSlices(
rewriter.setInsertionPoint(tiledDestStyleOp);
for (const auto &&[index, newRegionArg] :
llvm::enumerate(newRegionIterArgs)) {
- auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
- loc, newRegionArg, resultOffsets[index], resultSizes[index],
+ auto destSlice = tensor::ExtractSliceOp::create(
+ rewriter, loc, newRegionArg, resultOffsets[index],
+ resultSizes[index],
SmallVector<OpFoldResult>(resultOffsets[index].size(),
rewriter.getIndexAttr(1)));
// Make a copy of index to avoid a capturing structured binding, which
@@ -2388,8 +2389,8 @@ mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
Value strideVal =
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
- auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
- strideVal, ValueRange{});
+ auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
+ strideVal, ValueRange{});
loops.push_back(loop);
ivs.push_back(loop.getInductionVar());
rewriter.setInsertionPoint(loop.getBody()->getTerminator());
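
The nest construction in this file follows a fixed discipline: the innermost body yields `tensor.insert_slice` results, and each enclosing loop forwards its inner loop's results so tiled values flow outward, as in the yield-chaining hunk above. A condensed sketch of that forwarding step (helper name illustrative):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Terminate each outer body by forwarding the next-inner loop's results.
static void chainYields(RewriterBase &rewriter,
                        MutableArrayRef<scf::ForOp> loops) {
  for (auto [outer, inner] :
       llvm::zip(loops.drop_back(), loops.drop_front())) {
    rewriter.setInsertionPointToEnd(outer.getBody());
    scf::YieldOp::create(rewriter, outer.getLoc(), inner->getResults());
  }
}
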
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7e9a4d7..ec1044a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -189,7 +189,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
// dummy builder instead.
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
auto newLoop =
- rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
+ scf::ForOp::create(rewriter, loc, lb, ub, step, newArgs, emptyBuilder);
Block *newBody = newLoop.getBody();
@@ -236,18 +236,18 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
rewriter.setInsertionPointAfter(newLoop);
Value one;
if (isa<IndexType>(step.getType())) {
- one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ one = arith::ConstantIndexOp::create(rewriter, loc, 1);
} else {
- one = rewriter.create<arith::ConstantIntOp>(loc, step.getType(), 1);
+ one = arith::ConstantIntOp::create(rewriter, loc, step.getType(), 1);
}
- Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
- Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
- len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
- len = rewriter.create<arith::DivSIOp>(loc, len, step);
- len = rewriter.create<arith::SubIOp>(loc, len, one);
- Value res = rewriter.create<arith::MulIOp>(loc, len, step);
- res = rewriter.create<arith::AddIOp>(loc, lb, res);
+ Value stepDec = arith::SubIOp::create(rewriter, loc, step, one);
+ Value len = arith::SubIOp::create(rewriter, loc, ub, lb);
+ len = arith::AddIOp::create(rewriter, loc, len, stepDec);
+ len = arith::DivSIOp::create(rewriter, loc, len, step);
+ len = arith::SubIOp::create(rewriter, loc, len, one);
+ Value res = arith::MulIOp::create(rewriter, loc, len, step);
+ res = arith::AddIOp::create(rewriter, loc, lb, res);
// Reconstruct `scf.while` results, inserting final induction var value
// into proper place.
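
The chain of SubI/AddI/DivSI/MulI ops above reconstructs the induction variable's final value from the while-loop bounds. The same arithmetic as a self-contained check (plain C++, illustrative values):

#include <cassert>

int main() {
  int lb = 0, ub = 10, step = 3;
  int stepDec = step - 1;
  int len = (ub - lb + stepDec) / step; // trip count via ceiling division
  int res = lb + (len - 1) * step;      // induction value of the final trip
  assert(len == 4 && res == 9);         // iterations at 0, 3, 6, 9
  return 0;
}
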
diff --git a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
index f829208..db504fe 100644
--- a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
@@ -96,8 +96,8 @@ FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
// Create rotated while loop.
- auto newLoopOp = rewriter.create<scf::WhileOp>(
- whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
+ auto newLoopOp = scf::WhileOp::create(
+ rewriter, whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs,
[&](OpBuilder &builder, Location loc, ValueRange args) {
// Rotate and move the loop body into before block.
auto newBlock = builder.getBlock();
@@ -109,21 +109,21 @@ FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck(
},
[&](OpBuilder &builder, Location loc, ValueRange args) {
// Pass through values.
- builder.create<scf::YieldOp>(loc, args);
+ scf::YieldOp::create(builder, loc, args);
});
// Create zero-trip-check and move the while loop in.
- auto ifOp = rewriter.create<scf::IfOp>(
- whileOp.getLoc(), clonedCondition,
+ auto ifOp = scf::IfOp::create(
+ rewriter, whileOp.getLoc(), clonedCondition,
[&](OpBuilder &builder, Location loc) {
// Then runs the while loop.
rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(),
builder.getInsertionPoint());
- builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
+ scf::YieldOp::create(builder, loc, newLoopOp.getResults());
},
[&](OpBuilder &builder, Location loc) {
// Else returns the results from precondition.
- builder.create<scf::YieldOp>(loc, clonedCondArgs);
+ scf::YieldOp::create(builder, loc, clonedCondArgs);
});
rewriter.replaceOp(whileOp, ifOp);
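
The transform rotates a top-tested loop into a guarded bottom-tested one, which is exactly the while to if + do/while rewrite. In plain C++ (illustrative values):

#include <cstdio>

static bool cond(int x) { return x < 5; }
static int body(int x) { return x + 2; }

int main() {
  int x = 0;
  // Original top-tested loop:  while (cond(x)) x = body(x);
  // Rotated form: the check is hoisted into the scf.if so the loop proper
  // is bottom-tested and never executes zero times.
  if (cond(x)) {
    do {
      x = body(x);
    } while (cond(x));
  }
  std::printf("%d\n", x); // prints 6 in both forms
  return 0;
}
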
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 062268a..5731795 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -24,14 +24,12 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cstdint>
using namespace mlir;
#define DEBUG_TYPE "scf-utils"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
@@ -149,7 +147,7 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
originalTerminator->getOperandTypes());
auto outlinedFunc =
- rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
+ func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
// Merge blocks while replacing the original block operands.
@@ -164,8 +162,8 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
// Explicitly set up a new ReturnOp terminator.
rewriter.setInsertionPointToEnd(outlinedFuncBody);
- rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
- originalTerminator->getOperands());
+ func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
+ originalTerminator->getOperands());
}
// Reconstruct the block that was deleted and add a
@@ -181,7 +179,7 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
SmallVector<Value> callValues;
llvm::append_range(callValues, newBlock->getArguments());
llvm::append_range(callValues, outlinedValues);
- auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
+ auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
if (callOp)
*callOp = call;
@@ -270,12 +268,12 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
assert(dividend.getType().isIntOrIndex() &&
"expected integer or index-typed value");
- Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
- loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
- Value divisorCst = builder.create<arith::ConstantOp>(
- loc, builder.getIntegerAttr(dividend.getType(), divisor));
- Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
- return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
+ Value divisorMinusOneCst = arith::ConstantOp::create(
+ builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
+ Value divisorCst = arith::ConstantOp::create(
+ builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
+ Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
+ return arith::DivUIOp::create(builder, loc, sum, divisorCst);
}
// Build the IR that performs ceil division of a positive value by another
@@ -286,11 +284,11 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
Value divisor) {
assert(dividend.getType().isIntOrIndex() &&
"expected integer or index-typed value");
- Value cstOne = builder.create<arith::ConstantOp>(
- loc, builder.getOneAttr(dividend.getType()));
- Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
- Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
- return builder.create<arith::DivUIOp>(loc, sum, divisor);
+ Value cstOne = arith::ConstantOp::create(
+ builder, loc, builder.getOneAttr(dividend.getType()));
+ Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
+ Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
+ return arith::DivUIOp::create(builder, loc, sum, divisor);
}
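Both helpers materialize the usual add-then-divide ceiling division, relying on the identity ceil(a / b) == (a + b - 1) / b for positive operands. A quick numeric check:

  static_assert((10 + 4 - 1) / 4 == 3, "ceil(10 / 4) == 3");
  static_assert((12 + 4 - 1) / 4 == 3, "exact division is unaffected");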
/// Returns the trip count of `forOp` if its low bound, high bound and step are
@@ -400,18 +398,20 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
if (generateEpilogueLoop)
- upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
- upperBoundUnrolledCst));
+ upperBoundUnrolled = arith::ConstantOp::create(
+ boundsBuilder, loc,
+ boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
+ upperBoundUnrolledCst));
else
upperBoundUnrolled = forOp.getUpperBound();
// Create constant for 'stepUnrolled'.
- stepUnrolled = stepCst == stepUnrolledCst
- ? step
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getIntegerAttr(
- step.getType(), stepUnrolledCst));
+ stepUnrolled =
+ stepCst == stepUnrolledCst
+ ? step
+ : arith::ConstantOp::create(boundsBuilder, loc,
+ boundsBuilder.getIntegerAttr(
+ step.getType(), stepUnrolledCst));
} else {
// Dynamic loop bounds computation.
// TODO: Add dynamic asserts for negative lb/ub/step, or
@@ -419,22 +419,23 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
auto lowerBound = forOp.getLowerBound();
auto upperBound = forOp.getUpperBound();
Value diff =
- boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+ arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
- Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
+ Value unrollFactorCst = arith::ConstantOp::create(
+ boundsBuilder, loc,
+ boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
Value tripCountRem =
- boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
+ arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
Value tripCountEvenMultiple =
- boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
+ arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
// Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
- upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
- loc, lowerBound,
- boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
+ upperBoundUnrolled = arith::AddIOp::create(
+ boundsBuilder, loc, lowerBound,
+ arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
// Scale 'step' by 'unrollFactor'.
stepUnrolled =
- boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
+ arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
}
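In scalar form, the dynamic-bounds path computes the epilogue split as follows (a model with assumed concrete values, not the emitted IR):

  int64_t lb = 0, ub = 10, step = 1, unrollFactor = 4;
  int64_t tripCount = (ub - lb + step - 1) / step;        // 10
  int64_t tripCountRem = tripCount % unrollFactor;        // 2
  int64_t evenMultiple = tripCount - tripCountRem;        // 8
  int64_t upperBoundUnrolled = lb + evenMultiple * step;  // 8
  int64_t stepUnrolled = step * unrollFactor;             // 4
  // Iterations 8 and 9 fall past upperBoundUnrolled and run in the
  // epilogue loop with the original step.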
UnrolledLoopInfo resultLoops;
@@ -470,11 +471,11 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
forOp.getBody(), forOp.getInductionVar(), unrollFactor,
[&](unsigned i, Value iv, OpBuilder b) {
// iv' = iv + step * i;
- auto stride = b.create<arith::MulIOp>(
- loc, step,
- b.create<arith::ConstantOp>(loc,
- b.getIntegerAttr(iv.getType(), i)));
- return b.create<arith::AddIOp>(loc, iv, stride);
+ auto stride = arith::MulIOp::create(
+ b, loc, step,
+ arith::ConstantOp::create(b, loc,
+ b.getIntegerAttr(iv.getType(), i)));
+ return arith::AddIOp::create(b, loc, iv, stride);
},
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
@@ -522,13 +523,13 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
// If any control operand of any inner loop of `forOp` is defined within
// `forOp`, no unroll jam.
if (!areInnerBoundsInvariant(forOp)) {
- LDBG("failed to unroll and jam: inner bounds are not invariant");
+ LDBG() << "failed to unroll and jam: inner bounds are not invariant";
return failure();
}
// Currently, `for` operations with results are not supported.
if (forOp->getNumResults() > 0) {
- LDBG("failed to unroll and jam: unsupported loop with results");
+ LDBG() << "failed to unroll and jam: unsupported loop with results";
return failure();
}
@@ -537,16 +538,17 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount.has_value()) {
// If the trip count is dynamic, do not unroll & jam.
- LDBG("failed to unroll and jam: trip count could not be determined");
+ LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
if (unrollJamFactor > *tripCount) {
- LDBG("unroll and jam factor is greater than trip count, set factor to trip "
- "count");
+ LDBG() << "unroll and jam factor is greater than trip count, set factor to "
+ "trip "
+ "count";
unrollJamFactor = *tripCount;
} else if (*tripCount % unrollJamFactor != 0) {
- LDBG("failed to unroll and jam: unsupported trip count that is not a "
- "multiple of unroll jam factor");
+ LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
+ "multiple of unroll jam factor";
return failure();
}
@@ -777,13 +779,13 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
if (!isStepOne) {
Value origStepValue =
getValueOrCreateConstantIntOp(rewriter, loc, origStep);
- scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
+ scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
preserve.insert(scaled.getDefiningOp());
}
denormalizedIv = scaled;
if (!isZeroBased) {
Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
- denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
+ denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
preserve.insert(denormalizedIv.getDefiningOp());
}
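Denormalization inverts the usual normalization map. A scalar sketch of what the two conditionally created ops compute:

  // A normalized loop iterates over 0, 1, 2, ...; the original value is
  // recovered as origLb + normalizedIv * origStep. The mul is skipped when
  // origStep == 1 and the add when origLb == 0.
  int64_t denormalize(int64_t normalizedIv, int64_t origLb, int64_t origStep) {
    return origLb + normalizedIv * origStep;
  }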
@@ -819,15 +821,14 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
if (vOne && vOne.value() == 1)
continue;
if (productOf)
- productOf =
- rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
+ productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
+ .getResult();
else
productOf = v;
}
if (!productOf) {
- productOf = rewriter
- .create<arith::ConstantOp>(
- loc, rewriter.getOneAttr(getType(values.front())))
+ productOf = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getOneAttr(getType(values.front())))
.getResult();
}
return productOf.value();
@@ -846,9 +847,8 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
Value linearizedIv, ArrayRef<Value> ubs) {
if (linearizedIv.getType().isIndex()) {
- Operation *delinearizedOp =
- rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
- ubs);
+ Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
+ rewriter, loc, linearizedIv, ubs);
auto resultVals = llvm::map_to_vector(
delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
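For index types a single affine.delinearize_index op performs the whole mixed-radix split; the non-index fallback below reproduces it with div/rem chains. A scalar model for two bounds, last dimension fastest varying:

  // linear == iv0 * ub1 + iv1, with 0 <= iv1 < ub1.
  void delinearize2D(int64_t linear, int64_t ub1, int64_t &iv0, int64_t &iv1) {
    iv0 = linear / ub1; // slow dimension
    iv1 = linear % ub1; // fast dimension
  }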
@@ -870,8 +870,8 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
if (!isUbOne.test(index)) {
break;
}
- delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(ub.getType()));
+ delinearizedIvs[index] = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(ub.getType()));
numLeadingOneUbs++;
}
@@ -879,17 +879,17 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
- previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
+ previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
preservedUsers.insert(previous.getDefiningOp());
}
Value iv = previous;
if (i != e - 1) {
if (!isUbOne.test(idx)) {
- iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
+ iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
preservedUsers.insert(iv.getDefiningOp());
} else {
- iv = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(ubs[idx].getType()));
+ iv = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
}
}
delinearizedIvs[idx] = iv;
@@ -1089,13 +1089,13 @@ void mlir::collapseParallelLoops(
// Combine iteration spaces.
SmallVector<Value, 3> lowerBounds, upperBounds, steps;
- auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
for (auto &sortedDimension : sortedDimensions) {
- Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
for (auto idx : sortedDimension) {
- newUpperBound = rewriter.create<arith::MulIOp>(
- loc, newUpperBound, normalizedUpperBounds[idx]);
+ newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
+ normalizedUpperBounds[idx]);
}
lowerBounds.push_back(cst0);
steps.push_back(cst1);
@@ -1108,8 +1108,8 @@ void mlir::collapseParallelLoops(
// value. The remainders then determine, based on that range, which iteration
// of the original induction value this represents. This is a normalized value
// that has already been un-normalized by the previous logic.
- auto newPloop = rewriter.create<scf::ParallelOp>(
- loc, lowerBounds, upperBounds, steps,
+ auto newPloop = scf::ParallelOp::create(
+ rewriter, loc, lowerBounds, upperBounds, steps,
[&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
Value previous = ploopIVs[i];
@@ -1119,15 +1119,15 @@ void mlir::collapseParallelLoops(
unsigned idx = combinedDimensions[i][j];
// Determine the current induction value's current loop iteration
- Value iv = insideBuilder.create<arith::RemSIOp>(
- loc, previous, normalizedUpperBounds[idx]);
+ Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
+ normalizedUpperBounds[idx]);
replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
loops.getRegion());
// Remove the effect of the current induction value to prepare for
// the next value.
- previous = insideBuilder.create<arith::DivSIOp>(
- loc, previous, normalizedUpperBounds[idx]);
+ previous = arith::DivSIOp::create(insideBuilder, loc, previous,
+ normalizedUpperBounds[idx]);
}
// The final induction value is just the remaining value.
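Each original induction value is peeled off the collapsed one with a rem, and a div strips that dimension before the next round. A scalar model, where replaceUse stands in for replaceAllUsesInRegionWith:

  void splitCollapsedIv(int64_t collapsedIv, llvm::ArrayRef<int64_t> bounds,
                        llvm::function_ref<void(int64_t)> replaceUse) {
    int64_t previous = collapsedIv;
    for (int64_t ub : bounds) {
      replaceUse(previous % ub); // this dimension's iteration
      previous /= ub;            // drop this dimension's contribution
    }
  }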
@@ -1237,7 +1237,7 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
auto iv = forOp.getInductionVar();
OpBuilder b(forOp);
- forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
+ forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
Loops innerLoops;
for (auto t : targets) {
@@ -1247,12 +1247,12 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
// Insert newForOp before the terminator of `t`.
auto b = OpBuilder::atBlockTerminator((t.getBody()));
- Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
+ Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
Value ub =
- b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
+ arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
// Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
- auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
+ auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
newForOp.getBody()->getOperations().splice(
newForOp.getBody()->getOperations().begin(),
t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
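The net effect of stripmineSink is a classic two-level strip-mine: the outer step grows by the tile factor, and a clamped inner loop walks each tile. In pseudo-loop form:

  // Before:
  //   for (iv = lb; iv < ub; iv += step) body(iv);
  // After strip-mining by `factor`:
  //   for (iv = lb; iv < ub; iv += step * factor)
  //     for (tv = iv; tv < min(ub, iv + step * factor); tv += step)
  //       body(tv);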
@@ -1339,8 +1339,8 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
auto forOp = forOps[i];
OpBuilder builder(forOp);
auto loc = forOp.getLoc();
- Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
- forOp.getLowerBound());
+ Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
+ forOp.getLowerBound());
Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
Value iterationsPerBlock =
ceilDivPositive(builder, loc, numIterations, sizes[i]);
@@ -1372,9 +1372,10 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
// Create a new scf.forall op after the source loop.
rewriter.setInsertionPointAfter(source);
- scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
- source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
- source.getMixedStep(), fusedOuts, source.getMapping());
+ scf::ForallOp fusedLoop = scf::ForallOp::create(
+ rewriter, source.getLoc(), source.getMixedLowerBound(),
+ source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
+ source.getMapping());
// Map control operands.
IRMapping mapping;
@@ -1425,8 +1426,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
// Create a new scf.for op after the source loop (with scf.yield terminator
// (without arguments) only in case its init_args is empty).
rewriter.setInsertionPointAfter(source);
- scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
- source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+ scf::ForOp fusedLoop = scf::ForOp::create(
+ rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
source.getStep(), fusedInitArgs);
// Map original induction variables and operands to those of the fused loop.
@@ -1452,7 +1453,7 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
for (Value operand : source.getBody()->getTerminator()->getOperands())
yieldResults.push_back(mapping.lookupOrDefault(operand));
if (!yieldResults.empty())
- rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+ scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
// Replace old loops by substituting their uses by results of the fused loop.
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
@@ -1483,8 +1484,8 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
// Use the normalized builder since the lower bounds are always 0 and the
// steps are always 1.
- auto normalizedForallOp = rewriter.create<scf::ForallOp>(
- loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+ auto normalizedForallOp = scf::ForallOp::create(
+ rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
[](OpBuilder &, Location, ValueRange) {});
rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
diff --git a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp
index 66eed86..48c0b1e 100644
--- a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp
+++ b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp
@@ -30,14 +30,14 @@ Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value,
if (auto attrValue = dyn_cast<BitVectorAttr>(value)) {
assert(bvType == attrValue.getType() &&
"attribute and desired result types have to match");
- return builder.create<BVConstantOp>(loc, attrValue);
+ return BVConstantOp::create(builder, loc, attrValue);
}
}
// BoolType constants can materialize into smt.constant
if (auto boolType = dyn_cast<BoolType>(type)) {
if (auto attrValue = dyn_cast<BoolAttr>(value))
- return builder.create<BoolConstantOp>(loc, attrValue);
+ return BoolConstantOp::create(builder, loc, attrValue);
}
return nullptr;
diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
index 8977a3a..c517ef2 100644
--- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
+++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
@@ -405,7 +405,7 @@ static void buildQuantifier(
SmallVector<Location>(boundVarTypes.size(), odsState.location));
Value returnVal =
bodyBuilder(odsBuilder, odsState.location, block->getArguments());
- odsBuilder.create<smt::YieldOp>(odsState.location, returnVal);
+ smt::YieldOp::create(odsBuilder, odsState.location, returnVal);
}
if (patternBuilder) {
Region *region = odsState.addRegion();
@@ -416,7 +416,7 @@ static void buildQuantifier(
SmallVector<Location>(boundVarTypes.size(), odsState.location));
ValueRange returnVals =
patternBuilder(odsBuilder, odsState.location, block->getArguments());
- odsBuilder.create<smt::YieldOp>(odsState.location, returnVals);
+ smt::YieldOp::create(odsBuilder, odsState.location, returnVals);
}
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9a8e97..fcf1526 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -92,11 +92,13 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
/// as necessary.
void handleTerminator(Operation *op, Block *newDest) const final {
if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
- OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
+ auto builder = OpBuilder(op);
+ spirv::BranchOp::create(builder, op->getLoc(), newDest);
op->erase();
} else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
- OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
- retValOp->getOperands());
+ auto builder = OpBuilder(op);
+ spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
+ retValOp->getOperands());
op->erase();
}
}
@@ -665,19 +667,17 @@ static ParseResult parseStructMemberDecorations(
// Parse member decoration value if it exists.
if (succeeded(parser.parseOptionalEqual())) {
- auto memberDecorationValue =
- parseAndVerifyInteger<uint32_t>(dialect, parser);
-
- if (!memberDecorationValue)
+ Attribute memberDecorationValue;
+ if (failed(parser.parseAttribute(memberDecorationValue)))
return failure();
memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 1,
- memberDecoration.value(), memberDecorationValue.value());
+ static_cast<uint32_t>(memberTypes.size() - 1),
+ memberDecoration.value(), memberDecorationValue);
} else {
memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 0,
- memberDecoration.value(), 0);
+ static_cast<uint32_t>(memberTypes.size() - 1),
+ memberDecoration.value(), UnitAttr::get(dialect.getContext()));
}
return success();
};
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
-// `)>`
+// `)`
+// (`,` struct-decoration)?
+// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}
- if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+ if (failed(parser.parseRParen()))
+ return Type();
+
+ SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;
+
+ auto parseStructDecoration = [&]() {
+ std::optional<spirv::Decoration> decoration =
+ parseAndVerify<spirv::Decoration>(dialect, parser);
+ if (!decoration)
+ return failure();
+
+ // Parse decoration value if it exists.
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute decorationValue;
+ if (failed(parser.parseAttribute(decorationValue)))
+ return failure();
+
+ structDecorationInfo.emplace_back(decoration.value(), decorationValue);
+ } else {
+ structDecorationInfo.emplace_back(decoration.value(),
+ UnitAttr::get(dialect.getContext()));
+ }
+ return success();
+ };
+
+ while (succeeded(parser.parseOptionalComma()))
+ if (failed(parseStructDecoration()))
+ return Type();
+
+ if (failed(parser.parseGreater()))
return Type();
if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationInfo)))
+ memberDecorationInfo,
+ structDecorationInfo)))
return Type();
return idStructTy;
}
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+ structDecorationInfo);
}
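Concretely, struct decorations now trail the member list inside the angle brackets. Hypothetical forms the extended parser accepts (decoration names are examples, not an exhaustive list):

  // !spirv.struct<(f32 [0])>                  -- members only, as before
  // !spirv.struct<(f32 [0]), Block>           -- one unit struct decoration
  // !spirv.struct<(f32 [0]), Block, CPacked>  -- several decorations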
// spirv-type ::= array-type
@@ -882,8 +915,9 @@ static void print(StructType type, DialectAsmPrinter &os) {
}
auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
os << stringifyDecoration(decoration.decoration);
- if (decoration.hasValue) {
- os << "=" << decoration.decorationValue;
+ if (decoration.hasValue()) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
}
};
llvm::interleaveComma(decorations, os, eachFn);
@@ -892,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
- os << ")>";
+ os << ")";
+
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
+ type.getStructDecorations(decorations);
+ if (!decorations.empty()) {
+ os << ", ";
+ auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+ os << stringifyDecoration(decoration.decoration);
+ if (decoration.hasValue()) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
+ }
+ };
+ llvm::interleaveComma(decorations, os, eachFn);
+ }
+
+ os << ">";
}
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 46739bc..ddb3426 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -835,12 +835,14 @@ void SampledImageType::getCapabilities(
/// - for literal structs:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
/// Construct a storage object for an identified struct type. A struct type
/// associated with such storage must call StructType::trySetBody(...) later
@@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(StringRef identifier)
: memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+ numStructDecorations(0), structDecorationsInfo(nullptr),
identifier(identifier) {}
/// Construct a storage object for a literal struct type. A struct type
@@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
- StructType::MemberDecorationInfo const *memberDecorationsInfo)
+ StructType::MemberDecorationInfo const *memberDecorationsInfo,
+ unsigned numStructDecorations,
+ StructType::StructDecorationInfo const *structDecorationsInfo)
: memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
numMembers(numMembers), numMemberDecorations(numMemberDecorations),
- memberDecorationsInfo(memberDecorationsInfo) {}
+ memberDecorationsInfo(memberDecorationsInfo),
+ numStructDecorations(numStructDecorations),
+ structDecorationsInfo(structDecorationsInfo) {}
/// A storage key is divided into 2 parts:
/// - for identified structs:
@@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - an ArrayRef<Type> for member types;
/// - an ArrayRef<StructType::OffsetInfo> for member offset info;
/// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+ /// info;
+ /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
/// info.
///
/// An identified struct type is uniqued only by the first part (field 0)
/// of the key.
///
- /// A literal struct type is uniqued only by the second part (fields 1, 2, and
- /// 3) of the key. The identifier field (field 0) must be empty.
+ /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+ /// and 4) of the key. The identifier field (field 0) must be empty.
using KeyTy =
std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
- ArrayRef<StructType::MemberDecorationInfo>>;
+ ArrayRef<StructType::MemberDecorationInfo>,
+ ArrayRef<StructType::StructDecorationInfo>>;
/// For identified structs, return true if the given key contains the same
/// identifier.
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
- getMemberDecorationsInfo());
+ getMemberDecorationsInfo(), getStructDecorationsInfo());
}
/// If the given key contains a non-empty identifier, this method constructs
@@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
- return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
- numMemberDecorations, memberDecorationList);
+ const StructType::StructDecorationInfo *structDecorationList = nullptr;
+ unsigned numStructDecorations = 0;
+ if (!std::get<4>(key).empty()) {
+ auto keyStructDecorations = std::get<4>(key);
+ numStructDecorations = keyStructDecorations.size();
+ structDecorationList = allocator.copyInto(keyStructDecorations).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+ keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+ memberDecorationList, numStructDecorations, structDecorationList);
}
ArrayRef<Type> getMemberTypes() const {
@@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return {};
}
+ ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+ if (structDecorationsInfo)
+ return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+ numStructDecorations);
+ return {};
+ }
+
StringRef getIdentifier() const { return identifier; }
bool isIdentified() const { return !identifier.empty(); }
@@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - If called for an identified struct whose body was set before (through a
/// call to this method) but with different contents from the passed
/// arguments.
- LogicalResult mutate(
- TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
- ArrayRef<StructType::OffsetInfo> structOffsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+ LogicalResult
+ mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+ ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+ ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+ ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
if (!isIdentified())
return failure();
if (memberTypesAndIsBodySet.getInt() &&
(getMemberTypes() != structMemberTypes ||
getOffsetInfo() != structOffsetInfo ||
- getMemberDecorationsInfo() != structMemberDecorationInfo))
+ getMemberDecorationsInfo() != structMemberDecorationInfo ||
+ getStructDecorationsInfo() != structDecorationInfo))
return failure();
memberTypesAndIsBodySet.setInt(true);
@@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
allocator.copyInto(structMemberDecorationInfo).data();
}
+ if (!structDecorationInfo.empty()) {
+ numStructDecorations = structDecorationInfo.size();
+ structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+ }
+
return success();
}
@@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
+ unsigned numStructDecorations;
+ StructType::StructDecorationInfo const *structDecorationsInfo;
StringRef identifier;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::OffsetInfo> offsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+ ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructType::StructDecorationInfo> structDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
- SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+ SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
memberDecorations);
- llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+ llvm::array_pod_sort(sortedMemberDecorations.begin(),
+ sortedMemberDecorations.end());
+ SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
+ structDecorations);
+ llvm::array_pod_sort(sortedStructDecorations.begin(),
+ sortedStructDecorations.end());
+
return Base::get(memberTypes.vec().front().getContext(),
/*identifier=*/StringRef(), memberTypes, offsetInfo,
- sortedDecorations);
+ sortedMemberDecorations, sortedStructDecorations);
}
StructType StructType::getIdentified(MLIRContext *context,
@@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context,
return Base::get(context, identifier, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
}
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
StructType newStructType = Base::get(
context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
// Set an empty body in case this is an identified struct.
if (newStructType.isIdentified() &&
failed(newStructType.trySetBody(
ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>())))
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>())))
return StructType();
return newStructType;
@@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const {
bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+ for (StructType::StructDecorationInfo info :
+ getImpl()->getStructDecorationsInfo())
+ if (info.decoration == decoration)
+ return true;
+
+ return false;
+}
+
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
@@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations(
}
}
+void StructType::getStructDecorations(
+ SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+ const {
+ structDecorations.clear();
+ auto implDecorations = getImpl()->getStructDecorationsInfo();
+ structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo,
- ArrayRef<MemberDecorationInfo> memberDecorations) {
- return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+ ArrayRef<MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructDecorationInfo> structDecorations) {
+ return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+ structDecorations);
}
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value(
memberDecorationInfo.decoration);
}
+llvm::hash_code spirv::hash_value(
+ const StructType::StructDecorationInfo &structDecorationInfo) {
+ return llvm::hash_value(structDecorationInfo.decoration);
+}
+
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 81365b4..3911ec0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
auto varPtrType = cast<spirv::PointerType>(varType);
- auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
+ Type pointeeType = varPtrType.getPointeeType();
+
+  // Images are opaque types, so we can just return a pointer to the image.
+ // Note that currently only sampled images are supported in the SPIR-V
+ // lowering.
+ if (isa<spirv::SampledImageType>(pointeeType))
+ return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+
+ auto varPointeeType = cast<spirv::StructType>(pointeeType);
// Set the offset information.
varPointeeType =
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 35ec019..8f4c4cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ // Handle 8-bit floats.
+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+ auto bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 8)
+ return bitWidth / 8;
+ return std::nullopt;
+ }
+
if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
@@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns a nullptr for unsupported 8-bit float types.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+ FloatType type) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(type))
+ return IntegerType::get(type.getContext(), type.getWidth());
+ LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
+ return nullptr;
+}
+
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to an integer type of the same bit width. This is a no-op when
+/// the element type is not an 8-bit float type or emulation is disabled.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+ const SPIRVConversionOptions &options) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return type;
+ Type srcElementType = type.getElementType();
+ Type convertedElementType = nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(srcElementType))
+ convertedElementType = IntegerType::get(
+ type.getContext(), srcElementType.getIntOrFloatBitWidth());
+
+ if (!convertedElementType)
+ return type;
+
+ return type.clone(convertedElementType);
+}
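Taken together, the helpers make the emulation a pure bit-width-preserving retyping. Example mappings, assuming emulateUnsupportedFloatTypes is enabled:

  // f8E5M2            -> i8
  // f8E4M3FN          -> i8
  // vector<4xf8E4M3>  -> vector<4xi8>   (via convertShaped8BitFloatType)
  // f16, bf16         -> unchanged here (not 8-bit floats)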
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
+ type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
type = cast<TensorType>(convertIndexElementType(type, options));
+ type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+    // Handle 8-bit float types.
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+ arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
@@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (floatType.getWidth() == 8)
+ return convert8BitFloatType(this->options, floatType);
return Type();
});
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6a9b951..a53d0a7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() {
if (walkResult.wasInterrupted())
return signalPassFailure();
+ // Update min version requirement for capabilities after deducing them.
+ for (spirv::Capability cap : deducedCapabilities) {
+ if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
+ deducedVersion = std::max(deducedVersion, *minVersion);
+ if (deducedVersion > allowedVersion) {
+ module.emitError("Capability '")
+ << spirv::stringifyCapability(cap) << "' requires min version "
+ << spirv::stringifyVersion(deducedVersion)
+ << " but target environment allows up to "
+ << spirv::stringifyVersion(allowedVersion);
+ return signalPassFailure();
+ }
+ }
+ }
+
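The new check turns a silent version bump into a pass failure. The emitted diagnostic has the form below; the capability/version pairing is illustrative, not taken from a real run:

  // error: Capability 'GroupNonUniform' requires min version 1.3 but target
  // environment allows up to 1.0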
// TODO: verify that the deduced version is consistent with
// SPIR-V ops' maximal version requirements.
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 7805599..5ba8289 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -150,17 +150,17 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
- return builder.create<ub::PoisonOp>(loc, type, poison);
+ return ub::PoisonOp::create(builder, loc, type, poison);
if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
- return builder.create<ConstShapeOp>(
- loc, type, llvm::cast<DenseIntElementsAttr>(value));
+ return ConstShapeOp::create(builder, loc, type,
+ llvm::cast<DenseIntElementsAttr>(value));
if (llvm::isa<SizeType>(type))
- return builder.create<ConstSizeOp>(loc, type,
- llvm::cast<IntegerAttr>(value));
+ return ConstSizeOp::create(builder, loc, type,
+ llvm::cast<IntegerAttr>(value));
if (llvm::isa<WitnessType>(type))
- return builder.create<ConstWitnessOp>(loc, type,
- llvm::cast<BoolAttr>(value));
+ return ConstWitnessOp::create(builder, loc, type,
+ llvm::cast<BoolAttr>(value));
return arith::ConstantOp::materialize(builder, value, type, loc);
}
@@ -315,8 +315,8 @@ struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
auto newYieldOp =
rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
rewriter.setInsertionPoint(op);
- auto newOp = rewriter.create<AssumingOp>(
- op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
+ auto newOp = AssumingOp::create(
+ rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
newOp.getDoRegion().takeBody(op.getDoRegion());
// Use the new results to replace the previously used ones.
@@ -384,7 +384,7 @@ void AssumingOp::build(
// Build body.
SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
- builder.create<AssumingYieldOp>(result.location, yieldValues);
+ AssumingYieldOp::create(builder, result.location, yieldValues);
SmallVector<Type, 2> assumingTypes;
for (Value v : yieldValues)
@@ -735,13 +735,13 @@ struct BroadcastForwardSingleOperandPattern
if (replacement.getType() != op.getType()) {
auto loc = op.getLoc();
if (llvm::isa<ShapeType>(op.getType())) {
- replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
+ replacement = FromExtentTensorOp::create(rewriter, loc, replacement);
} else {
assert(!llvm::isa<ShapeType>(op.getType()) &&
!llvm::isa<ShapeType>(replacement.getType()) &&
"expect extent tensor cast");
replacement =
- rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
+ tensor::CastOp::create(rewriter, loc, op.getType(), replacement);
}
}
@@ -779,9 +779,9 @@ struct BroadcastFoldConstantOperandsPattern
auto foldedConstantOperandsTy = RankedTensorType::get(
{static_cast<int64_t>(foldedConstantShape.size())},
rewriter.getIndexType());
- newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
- op.getLoc(), foldedConstantOperandsTy,
- rewriter.getIndexTensorAttr(foldedConstantShape)));
+ newShapeOperands.push_back(
+ ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
+ rewriter.getIndexTensorAttr(foldedConstantShape)));
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
newShapeOperands);
return success();
@@ -844,9 +844,9 @@ struct BroadcastConcretizeResultTypePattern
}
}
- auto newOp = rewriter.create<BroadcastOp>(
- op.getLoc(), getExtentTensorType(getContext(), maxRank),
- op.getShapes());
+ auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
+ getExtentTensorType(getContext(), maxRank),
+ op.getShapes());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
@@ -1353,11 +1353,11 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
auto loc = result.location;
auto dimAttr = builder.getIndexAttr(dim);
if (llvm::isa<ShapeType>(shape.getType())) {
- Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
+ Value dim = ConstSizeOp::create(builder, loc, dimAttr);
build(builder, result, builder.getType<SizeType>(), shape, dim);
} else {
- Value dim =
- builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
+ Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
+ dimAttr);
build(builder, result, builder.getIndexType(), shape, dim);
}
}
@@ -1702,13 +1702,12 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
return failure();
Location loc = op.getLoc();
Value constShape =
- rewriter
- .create<ConstShapeOp>(loc,
- rewriter.getIndexTensorAttr(type.getShape()))
+ ConstShapeOp::create(rewriter, loc,
+ rewriter.getIndexTensorAttr(type.getShape()))
.getResult();
if (constShape.getType() != op.getResult().getType())
- constShape = rewriter.create<tensor::CastOp>(
- loc, op.getResult().getType(), constShape);
+ constShape = tensor::CastOp::create(rewriter, loc,
+ op.getResult().getType(), constShape);
rewriter.replaceOp(op, constShape);
return success();
}
@@ -1750,10 +1749,11 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
if (opTensorTy != shapeTensorTy) {
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
- shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
- else if (!isExtentTensorType(shapeTensorTy))
shape =
- rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
+ tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
+ else if (!isExtentTensorType(shapeTensorTy))
+ shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
+ shape);
}
rewriter.replaceOp(op, shape);
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index e405475..f6bc225 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -55,8 +55,8 @@ struct AssumingOpInterface
// Create new op and move over region.
TypeRange newResultTypes(yieldOp.getOperands());
- auto newOp = rewriter.create<shape::AssumingOp>(
- op->getLoc(), newResultTypes, assumingOp.getWitness());
+ auto newOp = shape::AssumingOp::create(
+ rewriter, op->getLoc(), newResultTypes, assumingOp.getWitness());
newOp.getDoRegion().takeBody(assumingOp.getRegion());
// Update all uses of the old op.
@@ -64,8 +64,9 @@ struct AssumingOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
if (isa<TensorType>(it.value())) {
- newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
- assumingOp.getLoc(), it.value(), newOp->getResult(it.index())));
+ newResults.push_back(bufferization::ToTensorOp::create(
+ rewriter, assumingOp.getLoc(), it.value(),
+ newOp->getResult(it.index())));
} else {
newResults.push_back(newOp->getResult(it.index()));
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
index 0fe1072..b636797 100644
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -66,7 +66,7 @@ createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
cluster.empty()
? b.getFunctionType(shape.getType(), shape.getType())
: b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
- shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
+ shape::FuncOp fnOp = shape::FuncOp::create(b, loc, fnName, fnType);
Block *block = fnOp.addEntryBlock();
b.setInsertionPointToEnd(block);
IRMapping bvm;
@@ -82,7 +82,7 @@ createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
llvm::SmallVector<Value, 4> fnReturns;
fnReturns.push_back(bvm.lookupOrDefault(shape));
- b.create<shape::ReturnOp>(loc, fnReturns);
+ shape::ReturnOp::create(b, loc, fnReturns);
fnOp.setPrivate();
return std::make_pair(fnOp, inputs);
}
@@ -184,7 +184,7 @@ class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
LogicalResult matchAndRewrite(tensor::DimOp op,
PatternRewriter &rewriter) const override {
auto shapeOf =
- rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
+ shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource());
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
op.getIndex());
return success();
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index d83ceab..3c363f3 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -43,14 +43,14 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
->materializeConstant(rewriter, rewriter.getIndexAttr(1),
valueType, loc)
->getResult(0);
- ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
+ ReduceOp reduce = ReduceOp::create(rewriter, loc, op.getShape(), init);
// Generate reduce operator.
Block *body = reduce.getBody();
OpBuilder b = OpBuilder::atBlockEnd(body);
- Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
- body->getArgument(2));
- b.create<shape::YieldOp>(loc, product);
+ Value product = MulOp::create(b, loc, valueType, body->getArgument(1),
+ body->getArgument(2));
+ shape::YieldOp::create(b, loc, product);
rewriter.replaceOp(op, reduce.getResult());
return success();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Shard/CMakeLists.txt
index fa8842f..fa8842f 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/CMakeLists.txt
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt
index 3fea4d6..70c604988 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt
@@ -1,11 +1,11 @@
-add_mlir_dialect_library(MLIRMeshDialect
- MeshOps.cpp
+add_mlir_dialect_library(MLIRShardDialect
+ ShardOps.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
DEPENDS
- MLIRMeshIncGen
+ MLIRShardIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 28608cb..08fccfa 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -1,4 +1,4 @@
-//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -37,13 +37,12 @@
#include <optional>
#include <utility>
-#define DEBUG_TYPE "mesh-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DEBUG_TYPE "shard-ops"
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
-#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
namespace {
@@ -74,11 +73,10 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
return lhs.value() * rhs.value();
}
-SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
- const Location &loc,
- llvm::ArrayRef<int64_t> statics,
- ValueRange dynamics,
- Type type) {
+SmallVector<Value>
+mlir::shard::getMixedAsValues(OpBuilder b, const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics, Type type) {
SmallVector<Value> values;
auto dyn = dynamics.begin();
Type i64 = b.getI64Type();
@@ -102,7 +100,7 @@ SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
//===----------------------------------------------------------------------===//
namespace {
-struct MeshInlinerInterface : public DialectInlinerInterface {
+struct ShardInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
// Currently no restrictions are encoded for inlining.
bool isLegalToInline(Operation *, Operation *, bool) const final {
@@ -118,44 +116,45 @@ struct MeshInlinerInterface : public DialectInlinerInterface {
} // namespace
//===----------------------------------------------------------------------===//
-// Mesh dialect
+// Shard dialect
//===----------------------------------------------------------------------===//
-void MeshDialect::initialize() {
+void ShardDialect::initialize() {
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
>();
- addInterface<MeshInlinerInterface>();
+  addInterface<ShardInlinerInterface>();
}
-Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
- Type type, Location loc) {
+Operation *ShardDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
//===----------------------------------------------------------------------===//
-// Mesh utilities
+// Shard utilities
//===----------------------------------------------------------------------===//
-static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
- FlatSymbolRefAttr meshSymbol,
+static FailureOr<GridOp> getGridAndVerify(Operation *op,
+ FlatSymbolRefAttr gridSymbol,
SymbolTableCollection &symbolTable) {
- mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
- if (!mesh) {
- return op->emitError() << "Undefined required mesh symbol \""
- << meshSymbol.getValue() << "\".";
+ shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
+ if (!grid) {
+ return op->emitError() << "Undefined required grid symbol \""
+ << gridSymbol.getValue() << "\".";
}
- return mesh;
+ return grid;
}
template <typename It>
@@ -175,20 +174,20 @@ bool isUnique(It begin, It end) {
return true;
}
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- MeshOp mesh) {
- SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
+ GridOp grid) {
+ SmallVector<GridAxis> sorted = llvm::to_vector(axes);
llvm::sort(sorted);
if (!isUnique(sorted.begin(), sorted.end())) {
- return emitError(loc) << "Mesh axes contains duplicate elements.";
+    return emitError(loc) << "Grid axes contain duplicate elements.";
}
- MeshAxis rank = mesh.getRank();
+ GridAxis rank = grid.getRank();
for (auto axis : axes) {
if (axis >= rank || axis < 0) {
return emitError(loc)
- << "0-based mesh axis index " << axis
- << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+ << "0-based grid axis index " << axis
+ << " is out of bounds. The referenced grid \"" << grid.getSymName()
<< "\" is of rank " << rank << ".";
}
}
@@ -197,22 +196,22 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
}
template <typename Op>
-static FailureOr<MeshOp>
-getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
- auto mesh =
- ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+static FailureOr<GridOp>
+getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+ auto grid =
+ ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
+ if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
return failure();
}
- return mesh;
+ return grid;
}
-template <typename InShape, typename MeshShape, typename SplitAxes,
+template <typename InShape, typename GridShape, typename SplitAxes,
typename OutShape>
-static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+static void shardShape(const InShape &inShape, const GridShape &gridShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
@@ -226,7 +225,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
llvm::adl_begin(outShape));
if (!shardedDimsOffsets.empty()) {
- auto isDynShape = ShapedType::isDynamicShape(meshShape);
+ auto isDynShape = ShapedType::isDynamicShape(gridShape);
uint64_t pos = 1;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
if (!innerSplitAxes.empty()) {
@@ -238,7 +237,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
// non-uniform offs in shardedDimsOffsets.
uint64_t numShards = 0;
for (auto i : innerSplitAxes.asArrayRef()) {
- numShards += meshShape[i];
+ numShards += gridShape[i];
}
for (size_t i = 1; i < numShards; ++i) {
if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
@@ -256,7 +255,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
outShape[tensorAxis] = shardDimension(
inShape[tensorAxis],
- collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+ collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
}
if (!haloSizes.empty()) {
@@ -279,25 +278,25 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
}
}
-ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
- MeshSharding sharding) {
+ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
+ Sharding sharding) {
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
- shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+ shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
resShapeArr, sharding.getStaticShardedDimsOffsets(),
sharding.getStaticHaloSizes());
return shape.clone(resShapeArr);
}
-Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
+Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
if (rankedTensorType && !rankedTensorType.getShape().empty()) {
- return shardShapedType(rankedTensorType, mesh, sharding);
+ return shardShapedType(rankedTensorType, grid, sharding);
}
return type;
}
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,
Value &operandValue,
Operation *operandOp,
OpBuilder &builder,
@@ -336,9 +335,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpResult result,
- OpBuilder &builder) {
+void mlir::shard::maybeInsertTargetShardingAnnotation(Sharding sharding,
+ OpResult result,
+ OpBuilder &builder) {
ShardOp newShardOp;
SmallVector<std::pair<Value, Operation *>> uses;
for (auto &use : result.getUses()) {
@@ -350,9 +349,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
}
}
-void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder) {
+void mlir::shard::maybeInsertSourceShardingAnnotation(Sharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandSrcOp = operandValue.getDefiningOp();
@@ -404,18 +403,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
}
//===----------------------------------------------------------------------===//
-// mesh.mesh op
+// shard.grid op
//===----------------------------------------------------------------------===//
-LogicalResult MeshOp::verify() {
+LogicalResult GridOp::verify() {
int64_t rank = getRank();
if (rank <= 0)
- return emitOpError("rank of mesh is expected to be a positive integer");
+ return emitOpError("rank of grid is expected to be a positive integer");
for (int64_t dimSize : getShape()) {
if (dimSize < 0 && ShapedType::isStatic(dimSize))
- return emitOpError("dimension size of a mesh is expected to be "
+ return emitOpError("dimension size of a grid is expected to be "
"non-negative or dynamic");
}
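For reference, a minimal sketch of IR this verifier accepts and rejects; the assembly format is assumed to carry over from the old mesh.mesh op with only the names substituted per this diff, and @grid0/@grid1 are made-up symbols:

  shard.grid @grid0(shape = 2x2x4)  // ok: rank 3, all dimension sizes non-negative
  shard.grid @grid1(shape = 4x?)    // ok: dynamic dimension sizes are allowed
  // A rank-0 grid, or a grid with a negative static dimension size, is
  // rejected by GridOp::verify above.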
@@ -423,21 +422,21 @@ LogicalResult MeshOp::verify() {
}
//===----------------------------------------------------------------------===//
-// mesh.mesh_shape op
+// shard.grid_shape op
//===----------------------------------------------------------------------===//
LogicalResult
-MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
return failure();
}
size_t expectedResultsCount =
- getAxes().empty() ? mesh->getRank() : getAxes().size();
+ getAxes().empty() ? grid->getRank() : getAxes().size();
if (getResult().size() != expectedResultsCount) {
return emitError() << "Unexpected number of results " << getResult().size()
<< ". Expected " << expectedResultsCount << ".";
@@ -446,53 +445,53 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- MeshOp mesh) {
- build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ GridOp grid) {
+ build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
}
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- MeshOp mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ GridOp grid, ArrayRef<GridAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+ SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
odsBuilder.getIndexType()),
- mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
+ grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
}
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef grid, ArrayRef<GridAxis> axes) {
assert(!axes.empty());
build(odsBuilder, odsState,
- SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
- MeshAxesAttr::get(odsBuilder.getContext(), axes));
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+ GridAxesAttr::get(odsBuilder.getContext(), axes));
}
-void MeshShapeOp::getAsmResultNames(
+void GridShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResults()[0], "mesh_shape");
+ setNameFn(getResults()[0], "grid_shape");
}
//===----------------------------------------------------------------------===//
-// mesh.sharding
+// shard.sharding
//===----------------------------------------------------------------------===//
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr mesh,
- ArrayRef<MeshAxesAttr> split_axes,
+ FlatSymbolRefAttr grid,
+ ArrayRef<GridAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
- b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+ llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
- return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
- MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
+ GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
{});
@@ -500,7 +499,7 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
@@ -508,16 +507,16 @@ void ShardingOp::build(
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
return build(
- b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- mlir::mesh::MeshSharding from) {
+ mlir::shard::Sharding from) {
- build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
- MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
+ build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
+ GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
from.getStaticShardedDimsOffsets().empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -529,21 +528,21 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
}
LogicalResult ShardingOp::verify() {
- llvm::SmallSet<MeshAxis, 4> visitedAxes;
+ llvm::SmallSet<GridAxis, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
- for (MeshAxis axis : axesArray) {
+ auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
+ for (GridAxis axis : axesArray) {
if (axis < 0)
- return emitError() << "mesh axis is expected to be non-negative";
+ return emitError() << "grid axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
- return emitError() << "mesh axis duplicated";
+ return emitError() << "grid axis duplicated";
}
return success();
};
for (auto subAxes : getSplitAxes().getAxes()) {
- ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
- if (failed(checkMeshAxis(subAxesArray)))
+ ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
+ if (failed(checkGridAxis(subAxesArray)))
return failure();
}
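For illustration, a sketch of the sharding values these builders construct and this verifier checks; the assembly is assumed to mirror the old mesh.sharding form, and @grid0 is a made-up symbol:

  // Split tensor dimension 0 over grid axis 0, and dimension 1 over grid
  // axes 1 and 2.
  %s = shard.sharding @grid0 split_axes = [[0], [1, 2]] : !shard.sharding
  // Rejected by ShardingOp::verify: grid axis 0 appears twice across
  // split_axes, so the visitedAxes check fails.
  %bad = shard.sharding @grid0 split_axes = [[0], [0]] : !shard.sharding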
@@ -572,26 +571,26 @@ void ShardingOp::getAsmResultNames(
}
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
+ if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
getStaticShardedDimsOffsets().size() > 0) {
return emitError() << "sharded dims offsets are not allowed for "
- "devices meshes with dynamic shape.";
+ "device grids with dynamic shape.";
}
auto shardedDimsOffsets = getStaticShardedDimsOffsets();
if (!shardedDimsOffsets.empty()) {
- auto meshShape = mesh.value().getShape();
- assert(ShapedType::isStaticShape(meshShape));
+ auto gridShape = grid.value().getShape();
+ assert(ShapedType::isStaticShape(gridShape));
uint64_t pos = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
if (!innerSplitAxes.empty()) {
int64_t numShards = 0, off = 0;
for (auto i : innerSplitAxes.asArrayRef()) {
- numShards += meshShape[i];
+ numShards += gridShape[i];
}
for (int64_t i = 0; i <= numShards; ++i) {
if (shardedDimsOffsets.size() <= pos + i) {
@@ -684,11 +683,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// MeshSharding
+// Sharding
//===----------------------------------------------------------------------===//
-bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
- if (getMesh() != rhs.getMesh()) {
+bool Sharding::equalSplitAxes(const Sharding &rhs) const {
+ if (getGrid() != rhs.getGrid()) {
return false;
}
@@ -701,16 +700,16 @@ bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
}
return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
- std::mem_fn(&MeshAxesAttr::empty)) &&
+ std::mem_fn(&GridAxesAttr::empty)) &&
llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
- std::mem_fn(&MeshAxesAttr::empty));
+ std::mem_fn(&GridAxesAttr::empty));
}
-bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloAndShardSizes(const Sharding &rhs) const {
return equalShardSizes(rhs) && equalHaloSizes(rhs);
}
-bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalShardSizes(const Sharding &rhs) const {
if (rhs.getStaticShardedDimsOffsets().size() !=
getStaticShardedDimsOffsets().size() ||
!llvm::equal(getStaticShardedDimsOffsets(),
@@ -726,7 +725,7 @@ bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
return true;
}
-bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloSizes(const Sharding &rhs) const {
if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
!llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
return false;
@@ -738,45 +737,43 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
return true;
}
-bool MeshSharding::operator==(Value rhs) const {
+bool Sharding::operator==(Value rhs) const {
return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
}
-bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
+bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
-bool MeshSharding::operator==(const MeshSharding &rhs) const {
+bool Sharding::operator==(const Sharding &rhs) const {
return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
}
-bool MeshSharding::operator!=(const MeshSharding &rhs) const {
- return !(*this == rhs);
-}
+bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
-MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
+Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
-MeshSharding::MeshSharding(Value rhs) {
- auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
+Sharding::Sharding(Value rhs) {
+ auto shardingOp = rhs.getDefiningOp<ShardingOp>();
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
// If splitAxes are empty, use "empty" constructor.
if (splitAxes.empty()) {
- *this = MeshSharding(shardingOp.getMeshAttr());
+ *this = Sharding(shardingOp.getGridAttr());
return;
}
*this =
- get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
+ get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
shardingOp.getStaticShardedDimsOffsets(),
SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
-MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
- ArrayRef<MeshAxesAttr> split_axes_,
- ArrayRef<int64_t> static_halo_sizes_,
- ArrayRef<int64_t> static_sharded_dims_offsets_,
- ArrayRef<Value> dynamic_halo_sizes_,
- ArrayRef<Value> dynamic_sharded_dims_offsets_) {
- MeshSharding res(mesh_);
+Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
+ ArrayRef<GridAxesAttr> split_axes_,
+ ArrayRef<int64_t> static_halo_sizes_,
+ ArrayRef<int64_t> static_sharded_dims_offsets_,
+ ArrayRef<Value> dynamic_halo_sizes_,
+ ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+ Sharding res(grid_);
if (split_axes_.empty()) {
return res;
}
@@ -784,7 +781,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
res.split_axes.resize(split_axes_.size());
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
res.split_axes[i] =
- MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+ GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
}
auto clone = [](const auto src, auto &dst) {
@@ -801,7 +798,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
}
//===----------------------------------------------------------------------===//
-// mesh.shard_shape
+// shard.shard_shape
//===----------------------------------------------------------------------===//
void ShardShapeOp::getAsmResultNames(
@@ -820,7 +817,7 @@ void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
}
//===----------------------------------------------------------------------===//
-// mesh.shard op
+// shard.shard op
//===----------------------------------------------------------------------===//
void ShardOp::getAsmResultNames(
@@ -850,10 +847,10 @@ public:
if (!otherOp || !otherOp->isBeforeInBlock(op)) {
return failure();
}
- // Create a MeshSharding object for the current and the other ShardOp
+ // Create a Sharding object for the current and the other ShardOp
// If the two are equal replace current op with the other op.
- MeshSharding currentSharding(op.getSharding());
- MeshSharding otherSharding(otherOp.getSharding());
+ Sharding currentSharding(op.getSharding());
+ Sharding otherSharding(otherOp.getSharding());
if (currentSharding == otherSharding) {
b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
b.eraseOp(op.getOperation());
@@ -876,21 +873,21 @@ void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// mesh.process_multi_index op
+// shard.process_multi_index op
//===----------------------------------------------------------------------===//
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
return failure();
}
size_t expectedResultsCount =
- getAxes().empty() ? mesh->getRank() : getAxes().size();
+ getAxes().empty() ? grid->getRank() : getAxes().size();
if (getResult().size() != expectedResultsCount) {
return emitError() << "Unexpected number of results " << getResult().size()
<< ". Expected " << expectedResultsCount << ".";
@@ -900,17 +897,17 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- MeshOp mesh) {
+ GridOp grid) {
build(odsBuilder, odsState,
- SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(), ArrayRef<MeshAxis>());
+ SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
+ grid.getSymName(), ArrayRef<GridAxis>());
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef mesh, ArrayRef<MeshAxis> axes) {
+ StringRef grid, ArrayRef<GridAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
- MeshAxesAttr::get(odsBuilder.getContext(), axes));
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+ GridAxesAttr::get(odsBuilder.getContext(), axes));
}
void ProcessMultiIndexOp::getAsmResultNames(
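A sketch of the renamed op in IR, assuming the `on`/`axes` assembly of the old mesh.process_multi_index and a made-up @grid0:

  // One index result per requested grid axis; with axes omitted, the op
  // yields one result per axis of @grid0, matching expectedResultsCount above.
  %idx:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index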
@@ -919,21 +916,21 @@ void ProcessMultiIndexOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.process_linear_index op
+// shard.process_linear_index op
//===----------------------------------------------------------------------===//
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
return success();
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
- OperationState &odsState, MeshOp mesh) {
- build(odsBuilder, odsState, mesh.getSymName());
+ OperationState &odsState, GridOp grid) {
+ build(odsBuilder, odsState, grid.getSymName());
}
void ProcessLinearIndexOp::getAsmResultNames(
@@ -942,13 +939,13 @@ void ProcessLinearIndexOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.neighbors_linear_indices op
+// shard.neighbors_linear_indices op
//===----------------------------------------------------------------------===//
LogicalResult
NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
return success();
@@ -967,12 +964,12 @@ void NeighborsLinearIndicesOp::getAsmResultNames(
namespace {
template <typename Op>
-struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;
LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override {
- auto meshAxes = op.getMeshAxes();
- if (!meshAxes.empty()) {
+ auto gridAxes = op.getGridAxes();
+ if (!gridAxes.empty()) {
return failure();
}
if (op.getInput().getType() != op.getResult().getType()) {
@@ -990,24 +987,24 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
ArrayRef<int64_t> device,
Operation::operand_range deviceDynamic,
- ArrayRef<MeshAxis> meshAxes,
- ArrayRef<int64_t> meshShape) {
- if (device.size() != meshAxes.size()) {
+ ArrayRef<GridAxis> gridAxes,
+ ArrayRef<int64_t> gridShape) {
+ if (device.size() != gridAxes.size()) {
return emitError(loc) << "In-group device \"" << deviceName
<< "\" has unexpected multi-index size "
- << device.size() << ". Expected " << meshAxes.size()
+ << device.size() << ". Expected " << gridAxes.size()
<< ".";
}
for (size_t i = 0; i < device.size(); ++i) {
if (ShapedType::isStatic(device[i]) &&
- ShapedType::isStatic(meshShape[meshAxes[i]]) &&
- meshShape[meshAxes[i]] <= device[i]) {
+ ShapedType::isStatic(gridShape[gridAxes[i]]) &&
+ gridShape[gridAxes[i]] <= device[i]) {
return emitError(loc)
<< "Out of bounds coordinate " << i << " for in-group device \""
<< deviceName << "\"."
<< " Got " << device[i] << ", but expected value in the range [0, "
- << (meshShape[meshAxes[i]] - 1) << "].";
+ << (gridShape[gridAxes[i]] - 1) << "].";
}
}
return success();
@@ -1043,7 +1040,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
- ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
auto resultRank = cast<ShapedType>(result.getType()).getRank();
if (gatherAxis < 0 || gatherAxis >= resultRank) {
return emitError(result.getLoc())
@@ -1054,7 +1051,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
auto deviceGroupSize =
- DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -1070,7 +1067,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
static LogicalResult verifyAllToAllOperandAndResultShape(
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
- ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1088,7 +1085,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
}
auto deviceGroupSize =
- DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
DimensionSize expectedResultConcatDimSize =
@@ -1115,7 +1112,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
Value operand, Value result, int64_t tensorAxis,
- ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1129,7 +1126,7 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
}
auto deviceGroupSize =
- DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
auto operandScatterDimSize =
DimensionSize(operandType.getDimSize(tensorAxis));
if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
@@ -1151,8 +1148,8 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
return success();
}
-static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
- ArrayRef<MeshAxis> meshAxes,
+static RankedTensorType sliceResultType(Type operandType, GridOp grid,
+ ArrayRef<GridAxis> gridAxes,
int64_t sliceAxis) {
RankedTensorType operandRankedTensorType =
cast<RankedTensorType>(operandType);
@@ -1163,29 +1160,29 @@ static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
resultShape[sliceAxis] =
operandSliceAxisSize /
- DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
return operandRankedTensorType.clone(resultShape);
}
//===----------------------------------------------------------------------===//
-// mesh.all_gather op
+// shard.all_gather op
//===----------------------------------------------------------------------===//
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
- gatherAxis, getMeshAxes(),
- mesh.value().getShape());
+ gatherAxis, getGridAxes(),
+ grid.value().getShape());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
}
void AllGatherOp::getAsmResultNames(
@@ -1194,23 +1191,23 @@ void AllGatherOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.all_reduce op
+// shard.all_reduce op
//===----------------------------------------------------------------------===//
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- return getMeshAndVerifyAxes(*this, symbolTable);
+ return getGridAndVerifyAxes(*this, symbolTable);
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
}
void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Value input, StringRef mesh,
- ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
- build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+ Value input, StringRef grid,
+ ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
+ build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
reduction);
}
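A sketch of the renamed collective and of the case the EmptyGridAxesCanonicalizationPattern folds away; the assembly is assumed to follow the old mesh.all_reduce form with grid_axes substituted for mesh_axes:

  // Reduce over grid axes 1 and 0 of @grid0.
  %1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
      : tensor<4xf32> -> tensor<4xf32>
  // With empty grid_axes the process group is a single process, so when the
  // operand and result types match the op canonicalizes to its input.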
@@ -1220,36 +1217,36 @@ void AllReduceOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.all_slice op
+// shard.all_slice op
//===----------------------------------------------------------------------===//
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
return verifyScatterOrSliceOperandAndResultShape(
- getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
- mesh.value().getShape());
+ getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
+ grid.value().getShape());
}
void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
}
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+ Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
int64_t sliceAxis) {
- Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
- build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+ Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
+ build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
sliceAxis);
}
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Type resultType, Value input, StringRef mesh,
- ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
- build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+ Type resultType, Value input, StringRef grid,
+ ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
+ build(odsBuilder, odsState, resultType, grid, gridAxes, input,
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
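To make the shape arithmetic concrete: collectiveProcessGroupSize is the product of the participating grid axis sizes, and sliceResultType divides the slice axis by that product. A sketch, assuming a 2x4 @grid0 and the old mesh.all_slice assembly with the renamed attribute:

  // Group size over grid axes [0, 1] is 2 * 4 = 8, so dimension 1
  // shrinks from 16 to 2.
  %1 = shard.all_slice %0 on @grid0 grid_axes = [0, 1] slice_axis = 1
      : tensor<8x16xf32> -> tensor<8x2xf32>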
@@ -1259,23 +1256,23 @@ void AllSliceOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.all_to_all op
+// shard.all_to_all op
//===----------------------------------------------------------------------===//
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
return verifyAllToAllOperandAndResultShape(
getOperand(), getResult(), getSplitAxis().getSExtValue(),
- getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
+ getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
}
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
}
void AllToAllOp::getAsmResultNames(
@@ -1284,18 +1281,18 @@ void AllToAllOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.broadcast op
+// shard.broadcast op
//===----------------------------------------------------------------------===//
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
@@ -1304,7 +1301,7 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
}
void BroadcastOp::getAsmResultNames(
@@ -1313,29 +1310,29 @@ void BroadcastOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.gather op
+// shard.gather op
//===----------------------------------------------------------------------===//
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
- getMeshAxes(),
- mesh.value().getShape());
+ getGridAxes(),
+ grid.value().getShape());
}
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
}
void GatherOp::getAsmResultNames(
@@ -1344,18 +1341,18 @@ void GatherOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.recv op
+// shard.recv op
//===----------------------------------------------------------------------===//
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (getSource() &&
failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
getSource().value(), getSourceDynamic(),
- getMeshAxes(), mesh.value().getShape()))) {
+ getGridAxes(), grid.value().getShape()))) {
return failure();
}
return success();
@@ -1363,7 +1360,7 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
}
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1371,17 +1368,17 @@ void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
}
//===----------------------------------------------------------------------===//
-// mesh.reduce op
+// shard.reduce op
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
@@ -1390,7 +1387,7 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
}
void ReduceOp::getAsmResultNames(
@@ -1399,24 +1396,24 @@ void ReduceOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.reduce_scatter op
+// shard.reduce_scatter op
//===----------------------------------------------------------------------===//
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
return verifyScatterOrSliceOperandAndResultShape(
- getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
- mesh.value().getShape());
+ getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
+ grid.value().getShape());
}
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
void ReduceScatterOp::getAsmResultNames(
@@ -1425,29 +1422,29 @@ void ReduceScatterOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.scatter op
+// shard.scatter op
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
auto scatterAxis = getScatterAxis().getSExtValue();
return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
- scatterAxis, getMeshAxes(),
- mesh.value().getShape());
+ scatterAxis, getGridAxes(),
+ grid.value().getShape());
}
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
}
void ScatterOp::getAsmResultNames(
@@ -1456,17 +1453,17 @@ void ScatterOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.send op
+// shard.send op
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
getDestination(), getDestinationDynamic(),
- getMeshAxes(), mesh.value().getShape()))) {
+ getGridAxes(), grid.value().getShape()))) {
return failure();
}
return success();
@@ -1474,7 +1471,7 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
}
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1482,20 +1479,20 @@ void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
}
//===----------------------------------------------------------------------===//
-// mesh.shift op
+// shard.shift op
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
- auto meshAxes = getMeshAxes();
+ auto gridAxes = getGridAxes();
auto shiftAxis = getShiftAxis().getZExtValue();
- if (!llvm::is_contained(meshAxes, shiftAxis)) {
+ if (!llvm::is_contained(gridAxes, shiftAxis)) {
return emitError() << "Invalid shift axis " << shiftAxis
- << ". It must be one of the grouping mesh axes.";
+ << ". It must be one of the grouping grid axes.";
}
return success();
@@ -1504,7 +1501,7 @@ LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// TODO: remove op when offset is 0 or if it is a rotate with and
- // offset % shift_axis_mesh_dim_size == 0.
+ // offset % shift_axis_grid_dim_size == 0.
}
void ShiftOp::getAsmResultNames(
@@ -1513,13 +1510,13 @@ void ShiftOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.update_halo op
+// shard.update_halo op
//===----------------------------------------------------------------------===//
LogicalResult
UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
@@ -1531,12 +1528,12 @@ UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
-#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
index afe76b5..01e8e56 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
DEPENDS
MLIRShardingInterfaceIncGen
@@ -10,7 +10,7 @@ add_mlir_library(MLIRShardingInterface
LINK_LIBS PUBLIC
MLIRDialectUtils
MLIRIR
- MLIRMeshDialect
+ MLIRShardDialect
MLIRTensorDialect
MLIRSupport
)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
index 6b3d49e..d4e7618 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
@@ -24,9 +24,9 @@
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"
//===----------------------------------------------------------------------===//
// common util functions
@@ -93,40 +93,39 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
}
template <typename T>
-SmallVector<MeshAxesAttr>
+SmallVector<GridAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
- SmallVector<MeshAxesAttr> res;
+ SmallVector<GridAxesAttr> res;
for (const auto &v : vec) {
- res.emplace_back(MeshAxesAttr::get(ctxt, v));
+ res.emplace_back(GridAxesAttr::get(ctxt, v));
}
return res;
}
//===----------------------------------------------------------------------===//
-// mesh::getMeshSharding
+// shard::getSharding
//===----------------------------------------------------------------------===//
-FailureOr<std::pair<bool, MeshSharding>>
-mesh::getMeshSharding(OpResult result) {
+FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpResult result) {
Value val = cast<Value>(result);
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
if (!shardOp)
return false;
return !shardOp.getAnnotateForUsers();
});
if (anyShardedForDef) {
- // expected to have exact one use if it has a use of `mesh.shard` without
+ // expected to have exactly one use if it has a use of `shard.shard` without
// unit attr annotate_for_users
if (!val.hasOneUse())
return failure();
- auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
- return std::make_pair(false, MeshSharding(shardOp.getSharding()));
+ auto shardOp = llvm::cast<shard::ShardOp>(*val.getUsers().begin());
+ return std::make_pair(false, Sharding(shardOp.getSharding()));
}
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
if (!shardOp)
return false;
return shardOp.getAnnotateForUsers();
@@ -138,24 +137,23 @@ mesh::getMeshSharding(OpResult result) {
if (shardOp)
shardOps.push_back(shardOp);
}
- MeshSharding shardForDef = shardOps[0].getSharding();
+ Sharding shardForDef = shardOps[0].getSharding();
for (size_t i = 1; i < shardOps.size(); ++i) {
- // TODO: Deduce a reasonable mesh sharding attr for def when they are
+ // TODO: Deduce a reasonable grid sharding attr for def when they are
// different
assert(shardForDef == shardOps[i].getSharding() &&
- "only support all shard ops have the same mesh sharding attr");
+ "only support all shard ops have the same grid sharding attr");
}
return std::make_pair(true, shardForDef);
}
return failure();
}
-FailureOr<std::pair<bool, MeshSharding>>
-mesh::getMeshSharding(OpOperand &opOperand) {
+FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpOperand &opOperand) {
Value val = opOperand.get();
if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
return std::make_pair(shardOp.getAnnotateForUsers(),
- MeshSharding(shardOp.getSharding()));
+ Sharding(shardOp.getSharding()));
return failure();
}
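A sketch of the annotation convention these two getSharding overloads inspect, assuming shard.shard keeps the mesh.shard assembly and %0 is some tensor-valued result:

  %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  // Annotates the producer of %0; found by getSharding(OpResult) among the
  // users that lack annotate_for_users.
  %a = shard.shard %0 to %s : tensor<4x8xf32>
  // Annotates the consumers instead; found by getSharding(OpOperand) on the
  // defining op when annotate_for_users is set.
  %b = shard.shard %0 to %s annotate_for_users : tensor<4x8xf32>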
@@ -164,7 +162,7 @@ mesh::getMeshSharding(OpOperand &opOperand) {
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//
-LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() {
Operation *op = getOperation();
// check operands and results type
@@ -201,7 +199,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//
-void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
os << "print loop types and indexing maps for: \n";
getOperation()->print(os);
os << "\n";
@@ -222,15 +220,15 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
namespace {
-// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
+// Update the given `shardingOption` according to `gridAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
- FlatSymbolRefAttr mesh,
- ArrayRef<MeshAxis> meshAxes,
+ FlatSymbolRefAttr grid,
+ ArrayRef<GridAxis> gridAxes,
unsigned loopIdx) {
- if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
+ if ((shardingOption.grid && grid && shardingOption.grid != grid) ||
(!shardingOption.shardingArray[loopIdx].empty() &&
- shardingOption.shardingArray[loopIdx] != meshAxes)) {
+ shardingOption.shardingArray[loopIdx] != gridAxes)) {
LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
<< loopIdx << "\n");
return failure();
@@ -239,28 +237,28 @@ static LogicalResult fillShardingOption(Operation *op,
if (i == loopIdx)
continue;
- for (MeshAxis axis : meshAxes) {
+ for (GridAxis axis : gridAxes) {
if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
- LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
+ LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes "
<< axis << " duplicate");
return failure();
}
}
}
- if (mesh)
- shardingOption.mesh = mesh;
+ if (grid)
+ shardingOption.grid = grid;
if (shardingOption.shardingArray[loopIdx].empty())
- shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
- meshAxes.end());
+ shardingOption.shardingArray[loopIdx].append(gridAxes.begin(),
+ gridAxes.end());
return success();
}
} // namespace
FailureOr<ShardingOption>
-mesh::detail::defaultGetShardingOption(Operation *op,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings) {
+shard::detail::defaultGetShardingOption(Operation *op,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings) {
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
ShardingOption shardingOption;
@@ -276,25 +274,25 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// 1. Fill sharding option based on op results
for (auto shardingIt : llvm::enumerate(resultShardings)) {
- MeshSharding shardAttr = shardingIt.value();
+ Sharding shardAttr = shardingIt.value();
if (!shardAttr)
continue;
AffineMap map = maps[numOperands + shardingIt.index()];
anyShardingInResultsOrOperands = true;
if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
- shardingOption.mesh = shardAttr.getMeshAttr();
+ shardingOption.grid = shardAttr.getGridAttr();
} else {
// Handle the split axes: calculate the corresponding loop index for each
// split axes sub-array, and then store the sub-array to
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption,
- shardAttr.getMeshAttr(), axes, index)))
+ shardAttr.getGridAttr(), axes, index)))
return failure();
}
}
@@ -302,7 +300,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// 2. Fill sharding option based on operands
for (auto shardingIt : llvm::enumerate(operandShardings)) {
- MeshSharding shardAttr = shardingIt.value();
+ Sharding shardAttr = shardingIt.value();
if (!shardAttr)
continue;
@@ -316,7 +314,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// then the operands with multiple loop indices.
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
checkOperandAffineExpr(expr, numDims);
if (failed(loopIndices))
@@ -329,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
unsigned loopIdx = *loopIndices->begin();
visitedLoopIndices.insert(loopIdx);
if (failed(fillShardingOption(op, shardingOption,
- shardAttr.getMeshAttr(), axes, loopIdx)))
+ shardAttr.getGridAttr(), axes, loopIdx)))
return failure();
}
// If multiple loop indices correspond to a dimension of an operand, it is
@@ -361,11 +359,11 @@ mesh::detail::defaultGetShardingOption(Operation *op,
}
// Get the sharding attributed for the given result and sharding option.
-MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
- AffineMap map,
- ArrayRef<utils::IteratorType> loopTypes) {
+static Sharding getSharding(OpResult result,
+ const ShardingOption &shardingOption, AffineMap map,
+ ArrayRef<utils::IteratorType> loopTypes) {
auto resultType = cast<RankedTensorType>(result.getType());
- SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+ SmallVector<SmallVector<GridAxis>> splitAxes(resultType.getRank());
// process the split axes
for (auto it : llvm::enumerate(map.getResults())) {
@@ -379,25 +377,25 @@ MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
}
removeTrailingEmptySubArray(splitAxes);
- return MeshSharding::get(shardingOption.mesh,
- fromArrayOfVector(result.getContext(), splitAxes));
+ return Sharding::get(shardingOption.grid,
+ fromArrayOfVector(result.getContext(), splitAxes));
}
-static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
- const ShardingOption &shardingOption,
- AffineMap map) {
+static FailureOr<Sharding> getSharding(OpOperand &opOperand,
+ const ShardingOption &shardingOption,
+ AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
if (!operandType) {
if (operandValue.getType().isIntOrIndexOrFloat())
- return MeshSharding();
+ return Sharding();
return failure();
}
// 0d tensors cannot be sharded and must get replicated
if (operandType.getRank() == 0) {
- return MeshSharding(shardingOption.mesh);
+ return Sharding(shardingOption.grid);
}
- SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
+ SmallVector<SmallVector<GridAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
int64_t idx = it.index();
@@ -422,15 +420,14 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
}
removeTrailingEmptySubArray(splitAxes);
- return MeshSharding::get(
- shardingOption.mesh,
+ return Sharding::get(
+ shardingOption.grid,
fromArrayOfVector(opOperand.get().getContext(), splitAxes));
}
-FailureOr<std::vector<MeshSharding>>
-mesh::detail::defaultGetShardingAnnotations(
+FailureOr<std::vector<Sharding>> shard::detail::defaultGetShardingAnnotations(
Operation *op, const ShardingOption &shardingOption) {
- std::vector<MeshSharding> res;
+ std::vector<Sharding> res;
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
@@ -439,7 +436,7 @@ mesh::detail::defaultGetShardingAnnotations(
unsigned numOperands = op->getNumOperands();
for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<MeshSharding> shardingAttr = getSharding(
+ FailureOr<Sharding> shardingAttr = ::getSharding(
opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
if (failed(shardingAttr))
return failure();
@@ -447,9 +444,9 @@ mesh::detail::defaultGetShardingAnnotations(
}
for (OpResult result : op->getResults()) {
- res.push_back(getSharding(result, shardingOption,
- maps[numOperands + result.getResultNumber()],
- loopTypes));
+ res.push_back(::getSharding(result, shardingOption,
+ maps[numOperands + result.getResultNumber()],
+ loopTypes));
}
return res;
@@ -459,26 +456,25 @@ mesh::detail::defaultGetShardingAnnotations(
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//
-// To add a `mesh.shard` op for the given result, based on the details provided
+// Add a `shard.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
const ShardingOption &shardingOption,
AffineMap map,
ArrayRef<utils::IteratorType> loopTypes) {
- MeshSharding sharding = getSharding(result, shardingOption, map, loopTypes);
+ Sharding sharding = getSharding(result, shardingOption, map, loopTypes);
maybeInsertTargetShardingAnnotation(sharding, result, b);
return success();
}
-// To add a `mesh.shard` op for the given operand, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
+// Add a `shard.shard` op for the given operand, based on the details
+// provided in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
- FailureOr<MeshSharding> sharding =
- getSharding(opOperand, shardingOption, map);
+ FailureOr<Sharding> sharding = getSharding(opOperand, shardingOption, map);
if (failed(sharding)) {
return failure();
}
@@ -488,9 +484,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
return success();
}
-LogicalResult mesh::detail::defaultAddShardingAnnotations(
+LogicalResult shard::detail::defaultAddShardingAnnotations(
Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
- assert(!shardingOption.empty && shardingOption.mesh);
+ assert(!shardingOption.empty && shardingOption.grid);
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
@@ -498,7 +494,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
- // 1. add mesh.shard ops for all op results
+ // 1. add shard.shard ops for all op results
for (OpResult result : op->getResults()) {
if (failed(addShardOp(b, result, shardingOption,
maps[numOperands + result.getResultNumber()],
@@ -506,7 +502,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
return failure();
}
- // 2. add mesh.shard ops for all operands
+ // 2. add shard.shard ops for all operands
for (OpOperand &opOperand : op->getOpOperands()) {
if (failed(addShardOp(b, opOperand, shardingOption,
maps[opOperand.getOperandNumber()])))
@@ -517,9 +513,8 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
}
#ifndef NDEBUG
-static bool
-isValueCompatibleWithFullReplicationSharding(Value value,
- MeshSharding sharding) {
+static bool isValueCompatibleWithFullReplicationSharding(Value value,
+ Sharding sharding) {
if (isa<RankedTensorType>(value.getType())) {
return isFullReplication(sharding);
}
@@ -527,60 +522,59 @@ isValueCompatibleWithFullReplicationSharding(Value value,
return !sharding;
}
-template <typename ValueRange, typename MeshShardingRage>
+template <typename ValueRange, typename ShardingRange>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
- MeshShardingRage &&shardings) {
+ ShardingRange &&shardings) {
if (std::size(values) != std::size(shardings)) {
return false;
}
- return llvm::all_of(
- llvm::zip_equal(std::forward<ValueRange>(values),
- std::forward<MeshShardingRage>(shardings)),
- [](auto valueAndSharding) {
- return isValueCompatibleWithFullReplicationSharding(
- std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
- });
+ return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values),
+ std::forward<ShardingRange>(shardings)),
+ [](auto valueAndSharding) {
+ return isValueCompatibleWithFullReplicationSharding(
+ std::get<0>(valueAndSharding),
+ std::get<1>(valueAndSharding));
+ });
}
#endif // NDEBUG
-void mesh::spmdizeFullyReplicatedOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable, OpBuilder &builder) {
- assert(spmdizedOperands.size() == operandShardings.size());
+void shard::partitionFullyReplicatedOperation(
+ Operation &op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap, SymbolTableCollection &symbolTable,
+ OpBuilder &builder) {
+ assert(partitionedOperands.size() == operandShardings.size());
assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
operandShardings));
assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
resultShardings));
// `clone` will populate the mapping of old to new results.
- builder.clone(op, spmdizationMap);
+ builder.clone(op, partitionMap);
}
-static void updateMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
- SmallVector<std::optional<SmallVector<MeshAxis>>>
- &meshAxesAssignmentForLoopIterators) {
+static void updateGridAxisAssignmentForLoopIterators(
+ ArrayRef<GridAxis> gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+ SmallVector<std::optional<SmallVector<GridAxis>>>
+ &gridAxesAssignmentForLoopIterators) {
AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
unsigned loopIteratorIdx = affineDimExpr.getPosition();
- if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
- assert(llvm::equal(meshAxesAssignmentForTensorAxis,
- *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+ if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+ assert(llvm::equal(gridAxesAssignmentForTensorAxis,
+ *gridAxesAssignmentForLoopIterators[loopIteratorIdx]));
} else {
- meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
- llvm::to_vector(meshAxesAssignmentForTensorAxis);
+ gridAxesAssignmentForLoopIterators[loopIteratorIdx] =
+ llvm::to_vector(gridAxesAssignmentForTensorAxis);
}
}
-ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+ShardingArray shard::getGridAxisAssignmentForLoopIterators(
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps) {
- SmallVector<std::optional<SmallVector<MeshAxis>>>
- meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
- std::vector<MeshSharding> operatorAndResultShardings;
+ SmallVector<std::optional<SmallVector<GridAxis>>>
+ gridAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+ std::vector<Sharding> operatorAndResultShardings;
operatorAndResultShardings.reserve(operandShardings.size() +
resultShardings.size());
llvm::append_range(operatorAndResultShardings, operandShardings);
@@ -589,69 +583,69 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
if (!sharding) {
continue;
}
- for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+ for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] :
llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
- updateMeshAxisAssignmentForLoopIterators(
- meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
- meshAxisAssignmentForLoopIterators);
+ updateGridAxisAssignmentForLoopIterators(
+ gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+ gridAxisAssignmentForLoopIterators);
}
// Missing trailing split axes means replication on those tensor dimensions.
for (unsigned i = sharding.getSplitAxes().size();
i < affineMap.getNumResults(); ++i) {
- updateMeshAxisAssignmentForLoopIterators(
- {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+ updateGridAxisAssignmentForLoopIterators(
+ {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators);
}
}
ShardingArray res;
- llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
- [](std::optional<SmallVector<MeshAxis>> &axes) {
+ llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res),
+ [](std::optional<SmallVector<GridAxis>> &axes) {
if (!axes) {
- return SmallVector<MeshAxis>();
+ return SmallVector<GridAxis>();
};
return std::move(*axes);
});
return res;
}
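The scatter performed by updateGridAxisAssignmentForLoopIterators is the heart of this computation: a tensor dimension's split axes are attributed to the loop iterator that indexes it, and two tensor dimensions driven by the same loop must agree. A standalone rendering with std types (hypothetical name):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

using Axes = std::vector<int16_t>;

void assignLoopAxes(const Axes &axesForTensorDim, unsigned loopIdx,
                    std::vector<std::optional<Axes>> &perLoop) {
  if (perLoop[loopIdx]) {
    // Another operand/result already fixed this loop's axes; they must match.
    assert(*perLoop[loopIdx] == axesForTensorDim && "inconsistent sharding");
  } else {
    perLoop[loopIdx] = axesForTensorDim;
  }
}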
-bool mesh::isAtLeastOneReductionIteratorSharded(
+bool shard::isAtLeastOneReductionIteratorSharded(
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
- for (auto [loopIteratorType, meshAxisAssignment] :
- llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
+ for (auto [loopIteratorType, gridAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
if (loopIteratorType == utils::IteratorType::reduction &&
- !meshAxisAssignment.empty()) {
+ !gridAxisAssignment.empty()) {
return true;
}
}
return false;
}
-SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+SmallVector<GridAxis> shard::getReductionGridAxes(
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
- SmallVector<MeshAxis> meshAxes;
- for (auto [loopIteratorType, meshAxisAssignment] :
- llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
+ SmallVector<GridAxis> gridAxes;
+ for (auto [loopIteratorType, gridAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
if (loopIteratorType == utils::IteratorType::reduction) {
- llvm::append_range(meshAxes, meshAxisAssignment);
+ llvm::append_range(gridAxes, gridAxisAssignment);
}
}
- return meshAxes;
+ return gridAxes;
}
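Together, the two queries above answer one question: which grid axes carry partial results. A hedged standalone sketch with simplified iterator kinds (not the MLIR utils types): for a matmul with loops [parallel, parallel, reduction] and assignment [[0], [], [1]] it returns {1}, signalling that partial sums along grid axis 1 still have to be combined, e.g. by an all-reduce.

#include <cstddef>
#include <cstdint>
#include <vector>

enum class IterKind { parallel, reduction };

std::vector<int16_t>
reductionGridAxes(const std::vector<IterKind> &loopKinds,
                  const std::vector<std::vector<int16_t>> &axesPerLoop) {
  std::vector<int16_t> axes;
  for (std::size_t i = 0; i < loopKinds.size(); ++i)
    if (loopKinds[i] == IterKind::reduction)
      axes.insert(axes.end(), axesPerLoop[i].begin(), axesPerLoop[i].end());
  // Non-empty exactly when isAtLeastOneReductionIteratorSharded returns true.
  return axes;
}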
-void mesh::spmdizeTriviallyShardableOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable, OpBuilder &builder) {
+void shard::partitionTriviallyShardableOperation(
+ Operation &op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap, SymbolTableCollection &symbolTable,
+ OpBuilder &builder) {
// `clone` will populate the mapping of old to new results.
- Operation *newOp = builder.clone(op, spmdizationMap);
+ Operation *newOp = builder.clone(op, partitionMap);
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(shardType(
newResult.getType(),
- getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
+ getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding));
}
}
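The setType calls above only touch the shape component of each result type. Ignoring halos, sharded-dims offsets, and dynamic dimensions (all of which the real shardType/shardShapedType handle), the assumed rule is: divide every split tensor dimension by the product of the sizes of its grid axes. A standalone sketch:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t>
shardShape(std::vector<int64_t> shape,
           const std::vector<std::vector<int16_t>> &splitAxes,
           const std::vector<int64_t> &gridShape) {
  for (std::size_t dim = 0; dim < splitAxes.size(); ++dim) {
    int64_t factor = 1;
    for (int16_t axis : splitAxes[dim])
      factor *= gridShape[axis];
    assert(shape[dim] % factor == 0 && "dimension must divide evenly");
    shape[dim] /= factor;
  }
  return shape;
}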
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
index 381bc9a..a884764 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
@@ -1,14 +1,14 @@
-add_mlir_dialect_library(MLIRMeshTransforms
+add_mlir_dialect_library(MLIRShardTransforms
Simplifications.cpp
ShardingPropagation.cpp
- Spmdization.cpp
+ Partition.cpp
Transforms.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
DEPENDS
- MLIRMeshPassIncGen
+ MLIRShardPassIncGen
MLIRShardingInterface
LINK_LIBS PUBLIC
@@ -21,7 +21,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
- MLIRMeshDialect
+ MLIRShardDialect
MLIRPass
MLIRSupport
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index c6e76ec..3e3d476 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -1,4 +1,4 @@
-//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//===- Partition.cpp ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Shard/Transforms/Partition.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -33,7 +33,7 @@
#include <optional>
#include <tuple>
-namespace mlir::mesh {
+namespace mlir::shard {
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
@@ -43,52 +43,49 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
});
}
-static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
- MeshSharding sourceSharding,
- int64_t splitTensorAxis,
- MeshAxis splitMeshAxis) {
- SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
+ Sharding sourceSharding,
+ int64_t splitTensorAxis,
+ GridAxis splitGridAxis) {
+ SmallVector<GridAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
splitTensorAxis) {
- targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
}
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
- targetSplitAxes.push_back(splitMeshAxis);
+ targetSplitAxes.push_back(splitGridAxis);
targetShardingSplitAxes[splitTensorAxis] =
- MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(sourceSharding.getMeshAttr(),
- targetShardingSplitAxes);
+ GridAxesAttr::get(ctx, targetSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
-// Split a replicated tensor along a mesh axis.
+// Split a replicated tensor along a grid axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
-// Returns the spmdized target value with its sharding.
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
+// Returns the partitioned target value with its sharding.
+static std::tuple<TypedValue<ShapedType>, Sharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
- MeshSharding sourceSharding,
- TypedValue<ShapedType> sourceShard, MeshOp mesh,
- int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ Sharding sourceSharding,
+ TypedValue<ShapedType> sourceShard, GridOp grid,
+ int64_t splitTensorAxis, GridAxis splitGridAxis) {
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- builder
- .create<AllSliceOp>(sourceShard, mesh,
- ArrayRef<MeshAxis>(splitMeshAxis),
- splitTensorAxis)
+ AllSliceOp::create(builder, sourceShard, grid,
+ ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
.getResult());
- MeshSharding targetSharding = targetShardingInSplitLastAxis(
- builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+ Sharding targetSharding = targetShardingInSplitLastAxis(
+ builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
}
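Concretely, for the [[0, 1]] -> [[0, 1, 2]] example above, the target sharding is obtained by appending grid axis 2 to the split list of tensor axis 0; the all_slice op then materializes the extra split. A toy version of targetShardingInSplitLastAxis on plain vectors (hypothetical name):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::vector<int16_t>>
appendSplitAxis(std::vector<std::vector<int16_t>> splitAxes,
                std::size_t tensorAxis, int16_t gridAxis) {
  if (splitAxes.size() <= tensorAxis)
    splitAxes.resize(tensorAxis + 1); // pad with empty (replicated) entries
  splitAxes[tensorAxis].push_back(gridAxis);
  return splitAxes;
}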
// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
-// If detected, returns the corresponding tensor axis mesh axis pair.
+// If detected, returns the corresponding (tensor axis, grid axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
-static std::optional<std::tuple<int64_t, MeshAxis>>
-detectSplitLastAxisInResharding(MeshSharding sourceSharding,
- MeshSharding targetSharding) {
+static std::optional<std::tuple<int64_t, GridAxis>>
+detectSplitLastAxisInResharding(Sharding sourceSharding,
+ Sharding targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
++tensorAxis) {
if (sourceSharding.getSplitAxes().size() > tensorAxis) {
@@ -118,16 +115,15 @@ detectSplitLastAxisInResharding(MeshSharding sourceSharding,
return std::nullopt;
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
- auto [tensorAxis, meshAxis] = detectRes.value();
- return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
- tensorAxis, meshAxis);
+ auto [tensorAxis, gridAxis] = detectRes.value();
+ return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
+ tensorAxis, gridAxis);
}
return std::nullopt;
@@ -135,10 +131,10 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis mesh axis pair.
-static std::optional<std::tuple<int64_t, MeshAxis>>
-detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
- MeshSharding targetSharding) {
+// If detected, returns the corresponding (tensor axis, grid axis) pair.
+static std::optional<std::tuple<int64_t, GridAxis>>
+detectUnsplitLastAxisInResharding(Sharding sourceSharding,
+ Sharding targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
++tensorAxis) {
if (targetSharding.getSplitAxes().size() > tensorAxis) {
@@ -165,10 +161,10 @@ detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
return std::nullopt;
}
-static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
- MeshSharding sourceSharding,
- int64_t splitTensorAxis) {
- SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+ Sharding sourceSharding,
+ int64_t splitTensorAxis) {
+ SmallVector<GridAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
splitTensorAxis);
@@ -177,9 +173,8 @@ static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
- MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(sourceSharding.getMeshAttr(),
- targetShardingSplitAxes);
+ GridAxesAttr::get(ctx, targetSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
static ShapedType allGatherResultShapeInUnsplitLastAxis(
@@ -190,45 +185,42 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis(
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
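The gather shape rule is simple: un-splitting a tensor axis concatenates the per-device shards along that axis, so its extent grows by the grid-axis size while every other dimension is unchanged. A sketch of the assumed computation:

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> allGatherShape(std::vector<int64_t> shardShape,
                                    std::size_t splitTensorAxis,
                                    int64_t gridAxisSize) {
  shardShape[splitTensorAxis] *= gridAxisSize;
  return shardShape;
}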
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
- MeshSharding sourceSharding,
- ShapedType sourceUnshardedShape,
- TypedValue<ShapedType> sourceShard, MeshOp mesh,
- int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
+ ImplicitLocOpBuilder &builder, Sharding sourceSharding,
+ ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
+ GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- MeshSharding targetSharding =
+ Sharding targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
- sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
+ sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
Value allGatherResult = AllGatherOp::create(
builder,
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
- mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+ grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
APInt(64, splitTensorAxis));
ShapedType targetShape =
- shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ shardShapedType(sourceUnshardedShape, grid, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
tensor::CastOp::create(builder, targetShape, allGatherResult)
.getResult());
return {targetShard, targetSharding};
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
- auto [tensorAxis, meshAxis] = detectRes.value();
+ auto [tensorAxis, gridAxis] = detectRes.value();
return unsplitLastAxisInResharding(builder, sourceSharding,
- sourceUnshardedShape, sourceShard, mesh,
- tensorAxis, meshAxis);
+ sourceUnshardedShape, sourceShard, grid,
+ tensorAxis, gridAxis);
}
return std::nullopt;
@@ -238,10 +230,10 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
-// target_tensor_axis, mesh_axis) tuple.
-static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
-detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
- MeshSharding targetSharding) {
+// target_tensor_axis, grid_axis) tuple.
+static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
+detectMoveLastSplitAxisInResharding(Sharding sourceSharding,
+ Sharding targetSharding) {
for (size_t sourceTensorAxis = 0;
sourceTensorAxis < sourceSharding.getSplitAxes().size();
++sourceTensorAxis) {
@@ -281,33 +273,32 @@ detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
return std::nullopt;
}
-static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
- MeshSharding sourceSharding,
- int64_t sourceTensorAxis,
- int64_t targetTensorAxis) {
- SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
+ Sharding sourceSharding,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
+ SmallVector<GridAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
targetTensorAxis) {
- targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
}
auto sourceSplitAxes =
llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
assert(!sourceSplitAxes.empty());
- auto meshAxis = sourceSplitAxes.back();
+ auto gridAxis = sourceSplitAxes.back();
sourceSplitAxes.pop_back();
targetShardingSplitAxes[sourceTensorAxis] =
- MeshAxesAttr::get(ctx, sourceSplitAxes);
+ GridAxesAttr::get(ctx, sourceSplitAxes);
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
- targetSplitAxes.push_back(meshAxis);
+ targetSplitAxes.push_back(gridAxis);
targetShardingSplitAxes[targetTensorAxis] =
- MeshAxesAttr::get(ctx, targetSplitAxes);
+ GridAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(sourceSharding.getMeshAttr(),
- targetShardingSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
@@ -322,46 +313,46 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
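The all-to-all counterpart moves a split factor between two tensor axes: the source axis, which loses its last split, grows by the grid-axis size, while the target axis, which gains it, shrinks by the same factor. Sketch of the assumed rule:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> allToAllShape(std::vector<int64_t> shape,
                                   int64_t gridAxisSize,
                                   std::size_t sourceAxis,
                                   std::size_t targetAxis) {
  shape[sourceAxis] *= gridAxisSize;
  assert(shape[targetAxis] % gridAxisSize == 0 && "must divide evenly");
  shape[targetAxis] /= gridAxisSize;
  return shape;
}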
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
+static std::tuple<TypedValue<ShapedType>, Sharding>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard,
int64_t sourceTensorAxis,
- int64_t targetTensorAxis, MeshAxis meshAxis) {
+ int64_t targetTensorAxis, GridAxis gridAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- MeshSharding targetSharding = targetShardingInMoveLastAxis(
+ Sharding targetSharding = targetShardingInMoveLastAxis(
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
- sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+ sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
targetTensorAxis);
Value allToAllResult = AllToAllOp::create(
builder,
RankedTensorType::get(allToAllResultShape.getShape(),
allToAllResultShape.getElementType()),
- mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+ grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
- shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ shardShapedType(sourceUnshardedShape, grid, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
return {targetShard, targetSharding};
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding,
+ Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
- auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
+ auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
return moveLastSplitAxisInResharding(
- builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
- sourceTensorAxis, targetTensorAxis, meshAxis);
+ builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
+ sourceTensorAxis, targetTensorAxis, gridAxis);
}
return std::nullopt;
@@ -371,10 +362,9 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// needed. Changed halo sizes require copying the "core" of the source tensor
// into the "core" of the destination tensor followed by an update halo
// operation.
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
// Currently handles only cases where halo sizes differ but everything else
@@ -392,7 +382,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
ShapedType::isStaticShape(tgtHaloSizes) &&
sourceShard.getType().hasStaticShape()) &&
- "dynamic shapes/halos are not supported yet for mesh-spmdization");
+ "dynamic shapes/halos are not supported yet for shard-partition");
auto rank = sourceShard.getType().getRank();
auto splitAxes = sourceSharding.getSplitAxes();
SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
@@ -428,56 +418,55 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// Finally update the halo.
auto updateHaloResult =
- builder
- .create<UpdateHaloOp>(
- sourceShard.getLoc(),
- RankedTensorType::get(outShape,
- sourceShard.getType().getElementType()),
- initOprnd, mesh.getSymName(),
- MeshAxesArrayAttr::get(builder.getContext(),
- sourceSharding.getSplitAxes()),
- targetSharding.getDynamicHaloSizes(),
- targetSharding.getStaticHaloSizes())
+ UpdateHaloOp::create(
+ builder, sourceShard.getLoc(),
+ RankedTensorType::get(outShape,
+ sourceShard.getType().getElementType()),
+ initOprnd, grid.getSymName(),
+ GridAxesArrayAttr::get(builder.getContext(),
+ sourceSharding.getSplitAxes()),
+ targetSharding.getDynamicHaloSizes(),
+ targetSharding.getStaticHaloSizes())
.getResult();
return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
targetSharding);
}
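A hedged sketch of the "core" bookkeeping this path relies on: the core is the shard minus its halos, it begins immediately after the low halo, and a pure halo update is only legal when the source and target cores have identical extents; the core is copied across and UpdateHaloOp then refreshes the halo regions.

#include <cstdint>

struct Core {
  int64_t offset; // where the core starts inside the shard
  int64_t extent; // size of the core along this dimension
};

// Per-dimension core of a shard with the given halo sizes (assumed layout:
// [low halo | core | high halo]).
Core coreOf(int64_t shardDim, int64_t lowHalo, int64_t highHalo) {
  return {lowHalo, shardDim - lowHalo - highHalo};
}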
-// Handles only resharding on a 1D mesh.
+// Handles only resharding on a 1D grid.
// Currently the sharded tensor axes must be exactly divisible by the single
-// mesh axis size.
+// grid axis size.
static TypedValue<ShapedType>
-reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding, MeshSharding targetSharding,
+reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
assert(sourceShard.getType() ==
- shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+ shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
[[maybe_unused]] ShapedType targetShardType =
- shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+ shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
assert(sourceShard.getType().getRank() == targetShardType.getRank());
- assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+ assert(grid.getRank() == 1 && "Only 1D grids are currently supported.");
if (sourceSharding == targetSharding) {
return sourceShard;
}
TypedValue<ShapedType> targetShard;
- MeshSharding actualTargetSharding;
+ Sharding actualTargetSharding;
if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
targetSharding.getStaticShardedDimsOffsets().empty() &&
sourceSharding.getStaticHaloSizes().empty() &&
targetSharding.getStaticHaloSizes().empty()) {
if (auto tryRes = tryMoveLastSplitAxisInResharding(
- builder, mesh, sourceSharding, targetSharding,
+ builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
} else if (auto tryRes =
- trySplitLastAxisInResharding(builder, mesh, sourceSharding,
+ trySplitLastAxisInResharding(builder, grid, sourceSharding,
targetSharding, sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
} else if (auto tryRes = tryUnsplitLastAxisInResharding(
- builder, mesh, sourceSharding, targetSharding,
+ builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
}
@@ -488,9 +477,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
return targetShard;
}
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination shardings are the same, no need to do anything.
@@ -500,28 +488,28 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
}
// Tries to handle the case where the resharding is needed because the halo
- // sizes are different. Supports arbitrary mesh dimensionality.
+ // sizes are different. Supports arbitrary grid dimensionality.
if (auto tryRes = tryUpdateHaloInResharding(
- builder, mesh, sourceSharding, targetSharding,
+ builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
return std::get<0>(tryRes.value()); // targetShard
}
- // Resort to handling only 1D meshes since the general case is complicated if
+ // Resort to handling only 1D grids since the general case is complicated if
// it needs to be communication-efficient in terms of minimizing the data
// transferred between devices.
- return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+ return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue, sourceShard);
}
-TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
assert(source.getResult() == target.getSrc());
auto sourceSharding = source.getSharding();
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
- return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+ return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
cast<TypedValue<ShapedType>>(source.getSrc()),
sourceShardValue);
}
@@ -530,21 +518,21 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue,
SymbolTableCollection &symbolTableCollection) {
- MeshOp srcMesh = getMesh(source, symbolTableCollection);
- assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
- return reshard(builder, srcMesh, source, target, sourceShardValue);
+ GridOp srcGrid = getGrid(source, symbolTableCollection);
+ assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
+ return reshard(builder, srcGrid, source, target, sourceShardValue);
}
void reshardingRegisterDependentDialects(DialectRegistry &registry) {
- registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
+ registry.insert<shard::ShardDialect, tensor::TensorDialect>();
}
-#define GEN_PASS_DEF_SPMDIZATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#define GEN_PASS_DEF_PARTITION
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
using UnshardedToShardedValueMap = DenseMap<Value, Value>;
-// Get the types of block arguments for an spmdized block.
+// Get the types of block arguments for a partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
SmallVector<Type>
@@ -563,35 +551,36 @@ shardedBlockArgumentTypes(Block &block,
Operation *useOp = *rankedTensorArg.getUsers().begin();
ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
assert(shardOp);
- MeshOp mesh = getMesh(shardOp, symbolTableCollection);
- return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
+ GridOp grid = getGrid(shardOp, symbolTableCollection);
+ return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
shardOp.getSharding()));
});
return res;
}
-static LogicalResult spmdizeOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+static LogicalResult
+partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
if (!shardingInterface) {
// If there is no sharding interface, we are conservative and assume that
// the op should be fully replicated on all devices.
- spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTableCollection, builder);
+ partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
+ resultShardings, partitionMap,
+ symbolTableCollection, builder);
} else {
- if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTableCollection, builder))) {
+ if (failed(shardingInterface.partition(
+ partitionedOperands, operandShardings, resultShardings,
+ partitionMap, symbolTableCollection, builder))) {
return failure();
}
}
- assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
- return spmdizationMap.contains(result);
+ assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
+ return partitionMap.contains(result);
}));
return success();
@@ -599,88 +588,87 @@ static LogicalResult spmdizeOperation(
// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
-static std::vector<MeshSharding> getOperandShardings(Operation &op) {
- std::vector<MeshSharding> res;
+static std::vector<Sharding> getOperandShardings(Operation &op) {
+ std::vector<Sharding> res;
res.reserve(op.getNumOperands());
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(operand);
if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
- return MeshSharding();
+ return Sharding();
}
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- return MeshSharding(shardOp.getSharding());
+ return Sharding(shardOp.getSharding());
});
return res;
}
// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not required to have an annotation.
-static std::vector<MeshSharding> getResultShardings(Operation &op) {
- std::vector<MeshSharding> res;
+static std::vector<Sharding> getResultShardings(Operation &op) {
+ std::vector<Sharding> res;
res.reserve(op.getNumResults());
llvm::transform(
op.getResults(), std::back_inserter(res), [&op](OpResult result) {
if (!result.hasOneUse() || result.use_empty()) {
- return MeshSharding();
+ return Sharding();
}
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
- return MeshSharding();
+ return Sharding();
}
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
if (shardOp) {
- return MeshSharding(shardOp.getSharding());
+ return Sharding(shardOp.getSharding());
}
if (rankedTensor.getType().getRank() == 0) {
// This is a 0d tensor result without explicit sharding.
- // Find mesh symbol from operands, if any.
- // Shardings without mesh are not always fully supported yet.
+ // Find grid symbol from operands, if any.
+ // Shardings without grid are not always fully supported yet.
for (auto operand : op.getOperands()) {
if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
- return MeshSharding(sharding.getMeshAttr());
+ return Sharding(sharding.getGridAttr());
}
}
}
- return MeshSharding();
+ return Sharding();
});
return res;
}
static LogicalResult
-spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection,
- OpBuilder &builder) {
- Value targetSpmdValue;
+partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ Value targetPartitionValue;
// Check if 2 shard ops are chained. If not, there is no need for resharding
// as the source and target share the same sharding.
- ShardOp srcShardOp =
- dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
+ ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
if (!srcShardOp) {
- targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
+ targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
} else {
// Insert resharding.
- TypedValue<ShapedType> srcSpmdValue =
- cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
- targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
- symbolTableCollection);
+ TypedValue<ShapedType> srcPartitionValue =
+ cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
+ targetPartitionValue = reshard(builder, srcShardOp, shardOp,
+ srcPartitionValue, symbolTableCollection);
}
- assert(!spmdizationMap.contains(shardOp.getResult()));
- spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ assert(!partitionMap.contains(shardOp.getResult()));
+ partitionMap.map(shardOp.getResult(), targetPartitionValue);
return success();
}
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection,
- OpBuilder &builder) {
+partitionOperation(Operation &op, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
if (isa<ShardingOp>(op)) {
return success();
}
@@ -690,30 +678,31 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
return op.emitError("expected a shard op as source of get_sharding");
}
auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
- spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
+ partitionMap.map(op.getResult(0), newSharding->getResult(0));
return success();
}
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
- return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
- builder);
+ return partitionOperation(shardOp, partitionMap, symbolTableCollection,
+ builder);
}
- SmallVector<Value> spmdizedOperands;
- llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
- [&spmdizationMap](Value operand) {
- assert(spmdizationMap.contains(operand));
- return spmdizationMap.lookup(operand);
+ SmallVector<Value> partitionedOperands;
+ llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
+ [&partitionMap](Value operand) {
+ assert(partitionMap.contains(operand));
+ return partitionMap.lookup(operand);
});
- return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
- getResultShardings(op), spmdizationMap,
- symbolTableCollection, builder);
+ return partitionOperation(op, partitionedOperands, getOperandShardings(op),
+ getResultShardings(op), partitionMap,
+ symbolTableCollection, builder);
}
-static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection,
- OpBuilder &builder) {
+static LogicalResult
+partitionBlock(Block &block, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
SmallVector<Location> argLocations;
llvm::transform(block.getArguments(), std::back_inserter(argLocations),
@@ -721,16 +710,16 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
Block *newBlock = builder.createBlock(
block.getParent(), {},
shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
- for (auto [unshardedBlockArg, spmdizedBlockArg] :
+ for (auto [unshardedBlockArg, partitionedBlockArg] :
llvm::zip(block.getArguments(), newBlock->getArguments())) {
- spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+ partitionMap.map(unshardedBlockArg, partitionedBlockArg);
}
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointToEnd(newBlock);
for (Operation &op : block.getOperations()) {
- if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
- builder))) {
+ if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
+ builder))) {
return failure();
}
}
@@ -739,8 +728,8 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
}
static LogicalResult
-spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection) {
+partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection) {
OpBuilder builder(op.getFunctionBody());
// Snapshot the original blocks to not mess up the iteration when adding new
@@ -754,8 +743,8 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
}
for (Block *block : originalBlocks) {
- if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
- builder))) {
+ if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
+ builder))) {
return failure();
}
}
@@ -788,22 +777,22 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
namespace {
-struct Spmdization : public impl::SpmdizationBase<Spmdization> {
+struct Partition : public impl::PartitionBase<Partition> {
void runOnOperation() override {
- IRMapping spmdizationMap;
+ IRMapping partitionMap;
SymbolTableCollection symbolTableCollection;
- if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
- symbolTableCollection))) {
+ if (failed(partitionFuncOp(getOperation(), partitionMap,
+ symbolTableCollection))) {
return signalPassFailure();
}
}
void getDependentDialects(DialectRegistry &registry) const override {
reshardingRegisterDependentDialects(registry);
- registry.insert<mesh::MeshDialect>();
+ registry.insert<shard::ShardDialect>();
}
};
} // namespace
-} // namespace mlir::mesh
+} // namespace mlir::shard
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index 09c754d..a647128c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -21,17 +21,17 @@
#include <vector>
namespace mlir {
-namespace mesh {
+namespace shard {
#define GEN_PASS_DEF_SHARDINGPROPAGATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
-} // namespace mesh
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
+} // namespace shard
} // namespace mlir
#define DEBUG_TYPE "sharding-propagation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
enum class ReshardingRquirementKind {
NO_RESHARDING = 0,
@@ -68,7 +68,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const ShardingOption &v) {
- return stream << "{empty = " << v.empty << ", mesh" << v.mesh
+ return stream << "{empty = " << v.empty << ", grid = " << v.grid
<< ", shardingArray = " << v.shardingArray << "}";
}
@@ -105,15 +105,15 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
// specific shardings. For example, mustShardings = [shard0, None] and
// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
// [shard0, None]]
-static SmallVector<std::vector<MeshSharding>>
-getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
- ArrayRef<MeshSharding> optionalShardings) {
- SmallVector<std::vector<MeshSharding>> allShardingAttrs;
- std::vector<MeshSharding> curShardingAttrs;
+static SmallVector<std::vector<Sharding>>
+getOrderedPossibleShardingAttrs(ArrayRef<Sharding> mustShardings,
+ ArrayRef<Sharding> optionalShardings) {
+ SmallVector<std::vector<Sharding>> allShardingAttrs;
+ std::vector<Sharding> curShardingAttrs;
std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
if (i == mustShardings.size()) {
- allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
+ allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs));
return;
}
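The DFS enumerates combinations most-annotated-first, matching the [[shard0, shard1], [shard0, None]] example in the comment above. A self-contained rendering with a toy stand-in for Sharding (hypothetical names):

#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

using S = std::optional<std::string>; // toy stand-in for a Sharding

std::vector<std::vector<S>> orderedCombinations(const std::vector<S> &must,
                                                const std::vector<S> &opt) {
  std::vector<std::vector<S>> all;
  std::vector<S> cur;
  std::function<void(std::size_t)> dfs = [&](std::size_t i) {
    if (i == must.size()) {
      all.push_back(cur);
      return;
    }
    if (must[i]) { // fixed annotation: no branching
      cur.push_back(must[i]);
      dfs(i + 1);
      cur.pop_back();
    } else if (opt[i]) { // optional annotation: try it first, then empty
      for (const S &choice : {opt[i], S{}}) {
        cur.push_back(choice);
        dfs(i + 1);
        cur.pop_back();
      }
    } else { // nothing known: only the empty sharding
      cur.push_back(S{});
      dfs(i + 1);
      cur.pop_back();
    }
  };
  dfs(0);
  return all;
}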
@@ -147,14 +147,14 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
// 1. No resharding is required (all existing annotations are compatible).
// 2. No resharding for operands/results that have an annotation specifically
// targeting this operation. This means
-// * operands that are the result of `mesh.shard` ops marked with
+// * operands that are the result of `shard.shard` ops marked with
// `annotate_for_users`.
-// * results that are annotated with `mesh.shard` ops without
+// * results that are annotated with `shard.shard` ops without
// `annotate_for_users`.
// 3. All other cases. Resharding is required for operands/results with
// annotations explicitly targeting this operation.
ReshardingRquirementKind getReshardingRquirementKind(
- Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
+ Operation *op, const std::vector<Sharding> &operandAndResultShardings) {
ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
size_t operandsCount = op->getOperands().size();
@@ -167,7 +167,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
for (auto [operand, sharding] :
llvm::zip_equal(op->getOperands(), operandShardings)) {
- ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
+ ShardOp shardOp = operand.getDefiningOp<ShardOp>();
if (!shardOp) {
continue;
}
@@ -213,14 +213,13 @@ ReshardingRquirementKind getReshardingRquirementKind(
// 3. Resharding of existing explicit sharding annotations for this op.
static FailureOr<ShardingOption> selectShardingOption(
ShardingInterface shardingOp,
- ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
- ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
+ ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs,
+ ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) {
SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
shardingOptionsAndReshardingRequirements;
- for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
- for (ArrayRef<MeshSharding> operandShardings :
- possibleOperandShardingAttrs) {
+ for (ArrayRef<Sharding> resultShardings : possibleResultShardingAttrs) {
+ for (ArrayRef<Sharding> operandShardings : possibleOperandShardingAttrs) {
FailureOr<ShardingOption> shardingOption =
shardingOp.getShardingOption(operandShardings, resultShardings);
if (failed(shardingOption) || shardingOption->empty) {
@@ -231,7 +230,7 @@ static FailureOr<ShardingOption> selectShardingOption(
// They may be missing some annotations.
// Whatever is returned by getShardingAnnotations is exactly what the op
// needs.
- FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
+ FailureOr<std::vector<Sharding>> operandAndResultShardings =
shardingOp.getShardingAnnotations(*shardingOption);
if (failed(operandAndResultShardings)) {
return failure();
@@ -276,13 +275,13 @@ static FailureOr<ShardingOption> selectShardingOption(
// For each operation that implements the ShardingInterface, infer the sharding
// option of the operation from its operands and/or results using the
// `getShardingOption` method. If the inferred sharding option is not empty, add
-// a `mesh.shard` operation for all remaining operands and results that do not
+// a `shard.shard` operation for all remaining operands and results that do not
// have sharding annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
- llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
+ llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op))
return success();
if (!shardingOp) {
@@ -290,14 +289,13 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
return failure();
}
- // collect MeshSharding from results
- std::vector<MeshSharding> allowConflictsResultShardings;
+ // collect Sharding from results
+ std::vector<Sharding> allowConflictsResultShardings;
allowConflictsResultShardings.resize(op->getNumResults());
- std::vector<MeshSharding> resultMustShardings;
+ std::vector<Sharding> resultMustShardings;
resultMustShardings.resize(op->getNumResults());
for (OpResult result : op->getResults()) {
- FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
- getMeshSharding(result);
+ FailureOr<std::pair<bool, Sharding>> maybeShardAttr = getSharding(result);
if (failed(maybeShardAttr))
continue;
if (!maybeShardAttr->first)
@@ -307,14 +305,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
maybeShardAttr->second;
}
- // collect MeshSharding from operands
- std::vector<MeshSharding> allowConflictsOperandShardings;
+ // collect Sharding from operands
+ std::vector<Sharding> allowConflictsOperandShardings;
allowConflictsOperandShardings.resize(op->getNumOperands());
- std::vector<MeshSharding> operandMustShardings;
+ std::vector<Sharding> operandMustShardings;
operandMustShardings.resize(op->getNumOperands());
for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
- getMeshSharding(opOperand);
+ FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
+ getSharding(opOperand);
if (failed(maybeShardAttr))
continue;
@@ -327,10 +325,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
}
// try to get the sharding option
- SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
+ SmallVector<std::vector<Sharding>> possibleOperandShardingAttrs =
getOrderedPossibleShardingAttrs(operandMustShardings,
allowConflictsOperandShardings);
- SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
+ SmallVector<std::vector<Sharding>> possibleResultShardingAttrs =
getOrderedPossibleShardingAttrs(resultMustShardings,
allowConflictsResultShardings);
FailureOr<ShardingOption> shardingOption = selectShardingOption(
@@ -358,7 +356,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
// ShardingPropagation
//===----------------------------------------------------------------------===//
struct ShardingPropagation
- : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+ : public shard::impl::ShardingPropagationBase<ShardingPropagation> {
using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
@@ -376,8 +374,7 @@ struct ShardingPropagation
LLVM_DEBUG(
DBGS() << "print all the ops' iterator types and indexing maps in the "
"block.\n";
- for (Operation &op
- : block.getOperations()) {
+ for (Operation &op : block.getOperations()) {
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
index 1315502..a17671e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
@@ -1,4 +1,4 @@
-//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
@@ -18,7 +18,7 @@
#include <numeric>
namespace mlir {
-namespace mesh {
+namespace shard {
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
@@ -52,53 +52,53 @@ namespace {
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
-// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder
- : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+// pass changing the referenced GridOp ops.
+struct GridShapeFolder
+ : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
using OpRewritePatternWithSymbolTableCollection::
OpRewritePatternWithSymbolTableCollection;
- LogicalResult matchAndRewrite(MeshShapeOp op,
+ LogicalResult matchAndRewrite(GridShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
- op.getOperation(), op.getMeshAttr());
- if (!mesh) {
+ GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
+ op.getOperation(), op.getGridAttr());
+ if (!grid) {
return failure();
}
- ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
- SmallVector<MeshAxis> opAxesIota;
- if (opMeshAxes.empty()) {
- opAxesIota.resize(mesh.getRank());
+ ArrayRef<GridAxis> opGridAxes = op.getAxes();
+ SmallVector<GridAxis> opAxesIota;
+ if (opGridAxes.empty()) {
+ opAxesIota.resize(grid.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
- opMeshAxes = opAxesIota;
+ opGridAxes = opAxesIota;
}
- if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
- return ShapedType::isDynamic(mesh.getShape()[axis]);
+ if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
+ return ShapedType::isDynamic(grid.getShape()[axis]);
})) {
- // All mesh dimensions are dynamic. Nothing to fold.
+ // All grid dimensions are dynamic. Nothing to fold.
return failure();
}
SmallVector<Value> newResults(op->getResults().size());
- SmallVector<MeshAxis> newShapeOpMeshAxes;
+ SmallVector<GridAxis> newShapeOpGridAxes;
SmallVector<size_t> newToOldResultsIndexMap;
- for (size_t i = 0; i < opMeshAxes.size(); ++i) {
- auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
- if (ShapedType::isDynamic(meshAxisSize)) {
+ for (size_t i = 0; i < opGridAxes.size(); ++i) {
+ auto gridAxisSize = grid.getShape()[opGridAxes[i]];
+ if (ShapedType::isDynamic(gridAxisSize)) {
newToOldResultsIndexMap.push_back(i);
- newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+ newShapeOpGridAxes.push_back(opGridAxes[i]);
} else {
- // Fold static mesh axes.
+ // Fold static grid axes.
newResults[i] = arith::ConstantOp::create(
- builder, builder.getIndexAttr(meshAxisSize));
+ builder, builder.getIndexAttr(gridAxisSize));
}
}
- // Leave only the dynamic mesh axes to be queried.
- if (!newShapeOpMeshAxes.empty()) {
- MeshShapeOp newShapeOp =
- MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes);
+ // Leave only the dynamic grid axes to be queried.
+ if (!newShapeOpGridAxes.empty()) {
+ GridShapeOp newShapeOp =
+ GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
}
@@ -113,8 +113,8 @@ struct MeshShapeFolder
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection) {
- patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
+ patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
}
-} // namespace mesh
+} // namespace shard
} // namespace mlir
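
The net effect of GridShapeFolder, in IR terms: queries for statically known grid axes become arith.constant ops, and only dynamic axes keep a (narrowed) shape query. The assembly below is illustrative of the renamed ops, not verbatim syntax:

    // Hypothetical grid with shape 2x?x4 (axis 1 dynamic).
    // Before folding:
    //   %s:3 = shard.grid_shape @grid : index, index, index
    // After GridShapeFolder:
    //   %c2 = arith.constant 2 : index                    // axis 0, folded
    //   %d1 = shard.grid_shape @grid axes = [1] : index   // still dynamic
    //   %c4 = arith.constant 4 : index                    // axis 2, folded
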
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index 1bde1af..772e66f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
@@ -14,8 +14,8 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -28,12 +28,12 @@
#include <iterator>
#include <numeric>
-namespace mlir::mesh {
+namespace mlir::shard {
namespace {
-/// Lower `mesh.process_multi_index` into expression using
-/// `mesh.process_linear_index` and `mesh.mesh_shape`.
+/// Lower `shard.process_multi_index` into an expression using
+/// `shard.process_linear_index` and `shard.grid_shape`.
struct ProcessMultiIndexOpLowering
: OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
using OpRewritePatternWithSymbolTableCollection::
@@ -41,30 +41,30 @@ struct ProcessMultiIndexOpLowering
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
- MeshOp mesh = getMesh(op, symbolTableCollection);
- if (!mesh) {
+ GridOp grid = getGrid(op, symbolTableCollection);
+ if (!grid) {
return failure();
}
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
- Value linearIndex = ProcessLinearIndexOp::create(builder, mesh);
- ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults();
+ Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
+ ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
SmallVector<Value> completeMultiIndex =
affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
- meshShape)
+ gridShape)
.getMultiIndex();
SmallVector<Value> multiIndex;
- ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
- SmallVector<MeshAxis> opAxesIota;
- if (opMeshAxes.empty()) {
- opAxesIota.resize(mesh.getRank());
+ ArrayRef<GridAxis> opGridAxes = op.getAxes();
+ SmallVector<GridAxis> opAxesIota;
+ if (opGridAxes.empty()) {
+ opAxesIota.resize(grid.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
- opMeshAxes = opAxesIota;
+ opGridAxes = opAxesIota;
}
- llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
- [&completeMultiIndex](MeshAxis meshAxis) {
- return completeMultiIndex[meshAxis];
+ llvm::transform(opGridAxes, std::back_inserter(multiIndex),
+ [&completeMultiIndex](GridAxis gridAxis) {
+ return completeMultiIndex[gridAxis];
});
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
@@ -86,15 +86,15 @@ struct AllSliceOpLowering
// axis.
// The slice axis is split into equisized parts with count
// the number of processes in the collective process group induced by
- // the mesh axes.
+ // the grid axes.
// The part for each process is determined by the corresponding
// linear-index in the process group.
//
// There are no collectives that require communication.
// Each process operates on its local tensor.
- MeshOp mesh = getMesh(op, symbolTableCollection);
- if (!mesh) {
+ GridOp grid = getGrid(op, symbolTableCollection);
+ if (!grid) {
return failure();
}
@@ -104,15 +104,15 @@ struct AllSliceOpLowering
Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
Operation::result_range processInGroupMultiIndex =
- ProcessMultiIndexOp::create(builder, mesh.getSymName(),
- op.getMeshAxes())
+ ProcessMultiIndexOp::create(builder, grid.getSymName(),
+ op.getGridAxes())
.getResults();
Operation::result_range processGroupShape =
- MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes())
+ GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
.getResult();
Value processGroupSize =
- createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+ createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder);
int64_t sliceAxis = op.getSliceAxis().getSExtValue();
Value operandSliceAxisSize =
@@ -125,7 +125,7 @@ struct AllSliceOpLowering
cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
"Slicing a tensor with axis size that is "
"not exactly divisible by the "
- "mesh process group size is not supported.");
+ "grid process group size is not supported.");
Value resultSliceAxisSize =
arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -172,7 +172,7 @@ void populateProcessMultiIndexOpLoweringPatterns(
}
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
- registry.insert<affine::AffineDialect, mesh::MeshDialect>();
+ registry.insert<affine::AffineDialect, shard::ShardDialect>();
}
void populateAllSliceOpLoweringPatterns(
@@ -183,7 +183,7 @@ void populateAllSliceOpLoweringPatterns(
void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, arith::ArithDialect,
- cf::ControlFlowDialect, mesh::MeshDialect,
+ cf::ControlFlowDialect, shard::ShardDialect,
tensor::TensorDialect>();
}
@@ -199,21 +199,21 @@ void registerAllOpLoweringDialects(DialectRegistry &registry) {
}
TypedValue<IndexType>
-createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
ImplicitLocOpBuilder &builder) {
- Operation::result_range meshShape =
- mesh::MeshShapeOp::create(builder, mesh, axes).getResults();
+ Operation::result_range gridShape =
+ GridShapeOp::create(builder, grid, axes).getResults();
return cast<TypedValue<IndexType>>(arith::createProduct(
- builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+ builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape),
builder.getIndexType()));
}
TypedValue<IndexType>
-createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
- ArrayRef<MeshAxis> meshAxes,
+createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder) {
Operation::result_range processGroupShape =
- MeshShapeOp::create(builder, mesh, meshAxes).getResult();
+ GridShapeOp::create(builder, grid, gridAxes).getResult();
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
@@ -225,11 +225,11 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
return cast<TypedValue<IndexType>>(res);
}
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
- ArrayRef<MeshAxis> meshAxes,
+TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
+ ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder) {
return createProcessLinearIndex(
- mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(),
- meshAxes, builder);
+ grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+ gridAxes, builder);
}
-} // namespace mlir::mesh
+} // namespace mlir::shard
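
The index arithmetic behind ProcessMultiIndexOpLowering and createProcessLinearIndex is plain mixed-radix (de)linearization over the grid shape, the same computation affine.delinearize_index and affine::linearizeIndex perform. A self-contained C++ sketch of the delinearization direction:

    #include <cstdint>
    #include <vector>

    // Row-major delinearization: recover a process's multi-index from its
    // linear index, given the grid shape (innermost axis varies fastest).
    std::vector<int64_t> delinearize(int64_t linear,
                                     const std::vector<int64_t> &gridShape) {
      std::vector<int64_t> multi(gridShape.size());
      for (auto i = static_cast<int64_t>(gridShape.size()) - 1; i >= 0; --i) {
        multi[i] = linear % gridShape[i];
        linear /= gridShape[i];
      }
      return multi;
    }

    // Example: on a 2x3x4 grid, linear index 17 maps to {1, 1, 1}:
    // 17 % 4 = 1, 17 / 4 = 4; 4 % 3 = 1, 4 / 3 = 1; 1 % 2 = 1.
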
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
index 3e3f584..60c9828 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
+++ b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
@@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
-namespace mesh {
+namespace shard {
template <typename Op>
struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
@@ -29,7 +29,7 @@ protected:
SymbolTableCollection &symbolTableCollection;
};
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
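
For orientation, this is how a pattern built on the helper above typically looks; the pattern name is hypothetical and the body is a sketch, but the symbolTableCollection member and the lookup mirror GridShapeFolder earlier in this diff:

    #include "TransformsDetail.h" // library-internal header shown above
    #include "mlir/Dialect/Shard/IR/ShardOps.h"

    namespace mlir::shard {
    // Sketch: a rewrite pattern that resolves grid symbols through the cached
    // SymbolTableCollection instead of rebuilding symbol tables per match.
    struct MyGridPattern
        : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
      using OpRewritePatternWithSymbolTableCollection::
          OpRewritePatternWithSymbolTableCollection;

      LogicalResult matchAndRewrite(GridShapeOp op,
                                    PatternRewriter &rewriter) const override {
        auto grid = symbolTableCollection.lookupNearestSymbolFrom<GridOp>(
            op.getOperation(), op.getGridAttr());
        if (!grid)
          return failure();
        // ... rewrite using `grid` ...
        return failure(); // sketch only
      }
    };
    } // namespace mlir::shard
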
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0262319..3b4140e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -931,10 +931,9 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
ny, args.drop_back(nTrailingP), createPartitionFunc);
- Value p = builder
- .create<func::CallOp>(loc, partitionFunc,
- TypeRange{IndexType::get(context)},
- args.drop_back(nTrailingP))
+ Value p = func::CallOp::create(builder, loc, partitionFunc,
+ TypeRange{IndexType::get(context)},
+ args.drop_back(nTrailingP))
.getResult(0);
Value lenLow = arith::SubIOp::create(builder, loc, p, lo);
@@ -1028,9 +1027,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
xPerm, ny, operands, createBinarySearchFunc);
- Value p = builder
- .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
- operands)
+ Value p = func::CallOp::create(builder, loc, searchFunc,
+ TypeRange{c1.getType()}, operands)
.getResult(0);
// Move the value at data[i] to a temporary location.
@@ -1317,7 +1315,7 @@ public:
Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
- auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+ auto nValue = n.getDefiningOp<arith::ConstantIndexOp>();
bool nIsOne = (nValue && nValue.value() == 1);
if (!op.getInbounds()) {
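
The SparseTensor hunks above (and many below) are one mechanical migration: the member-function builder `builder.create<Op>(...)` becomes the static `Op::create(builder, ...)`. Both spellings construct the same operation; the static form avoids the `builder` / `.create<...>` line break that clang-format forces on long chains. Condensed from the hunk above, with the names as they appear there:

    // Before: the chain starts on `builder`, so clang-format splits it.
    //   Value p = builder
    //                 .create<func::CallOp>(loc, partitionFunc,
    //                                       TypeRange{IndexType::get(context)},
    //                                       args.drop_back(nTrailingP))
    //                 .getResult(0);
    // After: same op, one call expression.
    //   Value p = func::CallOp::create(builder, loc, partitionFunc,
    //                                  TypeRange{IndexType::get(context)},
    //                                  args.drop_back(nTrailingP))
    //                 .getResult(0);
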
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index a317abd..0bd1d34 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -98,10 +98,10 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
Value numT = constantIndex(builder, loc, numThreads);
gpu::KernelDim3 gridSize = {one, one, one};
gpu::KernelDim3 blckSize = {numT, one, one};
- return builder
- .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
- /*dynSharedMemSz*/ none, args,
- builder.getType<gpu::AsyncTokenType>(), tokens)
+ return gpu::LaunchFuncOp::create(builder, loc, gpuFunc, gridSize, blckSize,
+ /*dynSharedMemSz*/ none, args,
+ builder.getType<gpu::AsyncTokenType>(),
+ tokens)
.getAsyncToken();
}
@@ -1168,7 +1168,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
ForallRewriter(MLIRContext *context, unsigned nT)
- : OpRewritePattern(context), numThreads(nT){};
+ : OpRewritePattern(context), numThreads(nT) {};
LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index dfb1274..9cd4896 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -443,8 +443,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
ValueRange inputs, Location loc) -> Value {
- return builder
- .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, TypeRange(spTp),
+ inputs)
.getResult(0);
});
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 70795e2..7a26cd3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -412,13 +412,13 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
if (memTp.getRank() > 1)
return mem;
// Truncate linear memrefs to given size.
- return builder
- .create<memref::SubViewOp>(
- loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
- mem, ValueRange{}, ValueRange{sz}, ValueRange{},
- ArrayRef<int64_t>{0}, // static offset
- ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
- ArrayRef<int64_t>{1}) // static stride
+ return memref::SubViewOp::create(
+ builder, loc,
+ MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+ mem, ValueRange{}, ValueRange{sz}, ValueRange{},
+ ArrayRef<int64_t>{0}, // static offset
+ ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
+ ArrayRef<int64_t>{1}) // static stride
.getResult();
}
@@ -449,7 +449,7 @@ class SparseInsertGenerator
public:
SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
bool genCall)
- : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
+ : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
/// Generates code along an insertion path without the need for a "cursor".
/// This current insertion strategy comes at the expense of some testing
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b444ac5..79f4e7f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -904,9 +904,8 @@ public:
dstTp->withoutDimToLvl(),
!srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
SmallVector<Value> dynSizes;
- Value buffer = rewriter
- .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
- nnz, Attribute())
+ Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes,
+ Value(), nnz, Attribute())
.getResult();
// Convert src coordinates to dst coordinates by first collapsing it to 1D
@@ -1013,9 +1012,8 @@ public:
!srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
Value buffer =
- rewriter
- .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
- /*sizeHint=*/nnz, Attribute())
+ AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(),
+ /*sizeHint=*/nnz, Attribute())
.getResult();
// Implement the sparse2sparse reshape as follows:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 0e96b59..869d27a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -115,8 +115,7 @@ public:
bufferization::BufferizationState bufferizationState;
- if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions,
+ if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();
diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
index 0421a6c..0784615 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
index dba5933..8f0b7da 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
@@ -1,10 +1,10 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
- MeshShardingExtensions.cpp
+ ShardingExtensions.cpp
)
-add_mlir_extension_library(MLIRTensorMeshShardingExtensions
- MeshShardingExtensions.cpp
+add_mlir_extension_library(MLIRTensorShardingExtensions
+ ShardingExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
@@ -22,5 +22,5 @@ add_mlir_extension_library(MLIRTensorAllExtensions
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
LINK_LIBS PUBLIC
- MLIRTensorMeshShardingExtensions
+ MLIRTensorShardingExtensions
)
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
index 7e4a5ac..ca7287c 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
@@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
using namespace mlir::tensor;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
@@ -40,20 +40,20 @@ struct CreatorOpShardingInterface
{AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
assert(resultShardings.size() == 1);
auto resType = cast<RankedTensorType>(op->getResult(0).getType());
- mlir::mesh::MeshOp mesh;
+ mlir::shard::GridOp grid;
ShapedType shardType;
if (resType.getRank() > 0) {
- mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+ grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
shardType =
- cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+ cast<ShapedType>(shard::shardType(resType, grid, resultShardings[0]));
} else {
shardType = resType;
}
@@ -67,7 +67,7 @@ struct CreatorOpShardingInterface
auto oldType = cast<ShapedType>(resType);
assert(oldType.getRank() == shardType.getRank());
int currOldOprndNum = -1;
- mesh::ShardShapeOp shapeForDevice;
+ shard::ShardShapeOp shapeForDevice;
ValueRange device;
Operation *newSharding = nullptr;
for (auto i = 0; i < oldType.getRank(); ++i) {
@@ -76,23 +76,23 @@ struct CreatorOpShardingInterface
newSharding =
ShardingOp::create(builder, op->getLoc(), resultShardings[0]);
device =
- mesh::ProcessMultiIndexOp::create(builder, op->getLoc(), mesh)
+ shard::ProcessMultiIndexOp::create(builder, op->getLoc(), grid)
.getResults();
- shapeForDevice = mesh::ShardShapeOp::create(
- builder, op->getLoc(), oldType.getShape(), spmdizedOperands,
+ shapeForDevice = shard::ShardShapeOp::create(
+ builder, op->getLoc(), oldType.getShape(), partitionedOperands,
newSharding->getResult(0), device);
}
newOperands.emplace_back(shapeForDevice.getResult()[i]);
} else if (oldType.isDynamicDim(i)) {
assert(shardType.isDynamicDim(i));
- newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
+ newOperands.emplace_back(partitionedOperands[++currOldOprndNum]);
}
}
newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands);
- spmdizationMap.map(op->getResult(0), newOp->getResult(0));
+ partitionMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
- newOp = builder.clone(*op, spmdizationMap);
+ newOp = builder.clone(*op, partitionMap);
}
newOp->getResult(0).setType(shardType);
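
To make the shape logic above concrete (numbers illustrative, assembly not verbatim): for statically shaped results the per-process type can be computed at compile time, while dynamic dimensions are deferred to shard.shard_shape at runtime, which is exactly the branch structure of the partition method.

    // A creator op on a size-4 grid axis, sharded along dimension 0:
    //   global view:   %t = tensor.empty() : tensor<1024x512xf32>
    //   per process:   %t = tensor.empty() : tensor<256x512xf32>
    // A dynamic dimension (tensor<?x512xf32>) cannot be divided statically,
    // so the code above emits shard.shard_shape fed by the process's
    // multi-index to compute the local extent at runtime.
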
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bc11e56..c3356c1 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -784,8 +784,8 @@ struct PadOpInterface
auto toValue = [&](OpFoldResult ofr) {
if (auto value = dyn_cast<Value>(ofr))
return value;
- return rewriter
- .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
+ return arith::ConstantIndexOp::create(rewriter, loc,
+ *getConstantIntValue(ofr))
.getResult();
};
@@ -919,9 +919,8 @@ struct ReshapeOpInterface
auto memrefType = MemRefType::get(
srcType.getShape(), srcType.getElementType(), AffineMap(),
cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
- srcBuffer = rewriter
- .create<bufferization::ToBufferOp>(
- op->getLoc(), memrefType, *tensorAlloc)
+ srcBuffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(),
+ memrefType, *tensorAlloc)
.getResult();
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 43d9d70..9fd27d3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -130,8 +130,7 @@ FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b,
// Create a tensor::ExtractSliceOp.
SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
- return b
- .create<ExtractSliceOp>(loc, newEmptyOp, offsets, emptyOp.getMixedSizes(),
- strides)
+ return ExtractSliceOp::create(b, loc, newEmptyOp, offsets,
+ emptyOp.getMixedSizes(), strides)
.getResult();
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index e0af2f7..2ec23e1 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -385,10 +385,9 @@ struct BubbleUpExpandShapeThroughExtractSlice
return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
});
OpFoldResult collapsedOffset =
- rewriter
- .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
- reassocGroupSizes,
- /*disjoint=*/true)
+ affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
+ reassocGroupSizes,
+ /*disjoint=*/true)
.getResult();
collapsedOffsets.push_back(collapsedOffset);
collapsedSizes.push_back(collapsedSize);
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index b1fac8c..c6a438d 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
LINK_LIBS PUBLIC
MLIRIR
- MLIRMeshDialect
+ MLIRShardDialect
MLIRShardingInterface
MLIRSupport
MLIRTosaDialect
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index d3a5f44..45994a7 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
@@ -19,7 +19,7 @@
using namespace mlir;
using namespace mlir::tosa;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
@@ -87,15 +87,15 @@ struct NegateOpSharding
return maps;
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
- spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTable, builder);
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ partitionTriviallyShardableOperation(*op, partitionedOperands,
+ operandShardings, resultShardings,
+ partitionMap, symbolTable, builder);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 606626d..6d2cbb5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -554,7 +554,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
Value input = op.getInput();
// Check the input to the CLAMP op is itself a CLAMP.
- auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
+ auto clampOp = input.getDefiningOp<tosa::ClampOp>();
if (!clampOp)
return failure();
@@ -707,9 +707,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
auto size_op =
getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
replaceWithSlice =
- rewriter
- .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
- input, start_op, size_op)
+ tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
+ input, start_op, size_op)
.getResult();
break;
}
@@ -1302,9 +1301,11 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();
- if (trunc) {
+ // i1 types are boolean in TOSA
+ if (outETy.isInteger(1)) {
+ intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
+ } else if (trunc) {
intVal = intVal.trunc(bitwidth);
- // i1 types are boolean in TOSA
} else if (unsignIn || inIntType.isInteger(1)) {
intVal = intVal.zext(bitwidth);
} else {
@@ -1634,7 +1635,7 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
for (Value operand : getOperands()) {
concatOperands.emplace_back(operand);
- auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
+ auto producer = operand.getDefiningOp<ConcatOp>();
if (!producer)
continue;
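
A worked instance of the CastOp::fold change above, which gives i1 results TOSA's boolean semantics (any nonzero value maps to true) instead of plain bit truncation:

    // cast(splat<4 : i8>) to i1, folding a splat constant:
    //   truncation keeps the low bit:  trunc(0b100) = 0   (old behavior)
    //   boolean semantics:             (4 != 0)     = 1   (new behavior)
    // For odd inputs the two agree, which is why the bug was easy to miss.
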
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 648e508a9..3cafb19 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -13,8 +13,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
@@ -166,7 +166,7 @@ void TosaDialect::initialize() {
>();
addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
declarePromisedInterfaces<
- mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
+ shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+static void printInitializationList(OpAsmPrinter &parser,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ parser << prefix << '(';
+ llvm::interleaveComma(
+ llvm::zip(blocksArgs, initializers), parser,
+ [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+ parser << ")";
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
- auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+
+ if (parser.parseOperand(cond))
return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
+
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Parse the optional block arguments
+ OptionalParseResult listResult =
+ parser.parseOptionalAssignmentList(regionArgs, operands);
+ if (listResult.has_value() && failed(listResult.value()))
return failure();
+
+ // Parse a colon.
+ if (failed(parser.parseColon()))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Parse the type of the condition operand
+ Type condType;
+ if (failed(parser.parseType(condType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Resolve operand with provided type
+ if (failed(parser.resolveOperand(cond, condType, result.operands)))
+ return failure();
+
+ // Parse optional block arg types
+ if (listResult.has_value()) {
+ FunctionType functionType;
+
+ if (failed(parser.parseType(functionType)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected list of types for block arguments "
+ << "followed by arrow type and list of return types";
+
+ result.addTypes(functionType.getResults());
+
+ if (functionType.getNumInputs() != operands.size()) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected as many input types as operands "
+ << "(expected " << operands.size() << " got "
+ << functionType.getNumInputs() << ")";
+ }
+
+ // Resolve input operands.
+ if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+ } else {
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ }
+
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
+
+ printInitializationList(p, getThenGraph().front().getArguments(),
+ getInputList(), " ");
+ p << " : ";
+ p << getCondition().getType();
+
+ if (!getInputList().empty()) {
+ p << " (";
+ llvm::interleaveComma(getInputList().getTypes(), p);
+ p << ")";
}
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printArrowTypeList(getResultTypes());
+ p << " ";
+
+ p.printRegion(getThenGraph());
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printRegion(elseRegion);
}
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
-static void printInitializationList(OpAsmPrinter &parser,
- Block::BlockArgListType blocksArgs,
- ValueRange initializers,
- StringRef prefix = "") {
- assert(blocksArgs.size() == initializers.size() &&
- "expected same length of arguments and initializers");
- if (initializers.empty())
- return;
-
- parser << prefix << '(';
- llvm::interleaveComma(
- llvm::zip(blocksArgs, initializers), parser,
- [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
- parser << ")";
-}
-
void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
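
With the parse/print changes above, tosa.cond_if can carry explicit block arguments and round-trip them; an illustrative (not verbatim) form of the printed syntax:

    // %r = tosa.cond_if %cond (%a = %x, %b = %y) : tensor<i1>
    //      (tensor<f32>, tensor<f32>) -> tensor<f32> {
    // ^bb0(%a: tensor<f32>, %b: tensor<f32>):
    //   tosa.yield %a : tensor<f32>
    // } else {
    // ^bb0(%a: tensor<f32>, %b: tensor<f32>):
    //   tosa.yield %b : tensor<f32>
    // }
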
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 9474299..0bec0da 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -81,9 +81,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto revisedInputShapeValue =
getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
- input = rewriter
- .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
- revisedInputShapeValue)
+ input = tosa::ReshapeOp::create(rewriter, op.getLoc(), inputType, input,
+ revisedInputShapeValue)
.getResult();
Type resultETy = resultType.getElementType();
@@ -162,9 +161,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
Value constZero =
tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr);
- Value mulValue = rewriter
- .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
- weight, constZero)
+ Value mulValue = tosa::MulOp::create(rewriter, op.getLoc(), mulShapeType,
+ input, weight, constZero)
.getResult();
// Reshape output to [N, H, W, C * M].
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f36..9543fa1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ if (failed(maybeProfDef) && failed(maybeExtDef))
+ return success();
- if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
- !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+ const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb6..8ec7765 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
- // %0 = tosa.cond_if %arg2 {
- // tosa.yield %arg0
+ // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
// } else {
- // tosa.yield %arg1
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
// }
- //
- // Unfortunately, the simplified syntax does not encapsulate values
- // used in then/else regions (see 'simplified' example above), so it
- // must be rewritten to use the generic syntax in order to be conformant
- // to the specification.
+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 4662836..14a4fdf 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -16,15 +16,13 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/iterator.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "transform-dialect"
-#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
-#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
+#define FULL_LDBG() LDBG(4)
using namespace mlir;
@@ -486,24 +484,20 @@ void transform::TransformState::recordOpHandleInvalidationOne(
newlyInvalidated.count(otherHandle))
return;
- FULL_LDBG("--recordOpHandleInvalidationOne\n");
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
- (DBGS() << "--ancestors: "
- << llvm::interleaved(llvm::make_pointee_range(potentialAncestors))
- << "\n");
- });
+ FULL_LDBG() << "--recordOpHandleInvalidationOne";
+ FULL_LDBG() << "--ancestors: "
+ << llvm::interleaved(
+ llvm::make_pointee_range(potentialAncestors));
Operation *owner = consumingHandle.getOwner();
unsigned operandNo = consumingHandle.getOperandNumber();
for (Operation *ancestor : potentialAncestors) {
// clang-format off
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
- { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
- { (DBGS() << "----of payload with name: "
- << payloadOp->getName().getIdentifier() << "\n"); });
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
- { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
+ FULL_LDBG() << "----handle one ancestor: " << *ancestor;;
+
+ FULL_LDBG() << "----of payload with name: "
+ << payloadOp->getName().getIdentifier();
+ FULL_LDBG() << "----of payload: " << *payloadOp;
// clang-format on
if (!ancestor->isAncestor(payloadOp))
continue;
@@ -609,10 +603,8 @@ void transform::TransformState::recordOpHandleInvalidation(
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
if (potentialAncestors.empty()) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
- (DBGS() << "----recording invalidation for empty handle: " << handle.get()
- << "\n");
- });
+ FULL_LDBG() << "----recording invalidation for empty handle: "
+ << handle.get();
Operation *owner = handle.getOwner();
unsigned operandNo = handle.getOperandNumber();
@@ -709,7 +701,7 @@ void transform::TransformState::recordValueHandleInvalidation(
LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
transform::TransformOpInterface transform,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
- FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
+ FULL_LDBG() << "--Start checkAndRecordHandleInvalidation";
auto memoryEffectsIface =
cast<MemoryEffectOpInterface>(transform.getOperation());
SmallVector<MemoryEffects::EffectInstance> effects;
@@ -717,9 +709,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
transform::TransformMappingResource::get(), effects);
for (OpOperand &target : transform->getOpOperands()) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
- (DBGS() << "----iterate on handle: " << target.get() << "\n");
- });
+ FULL_LDBG() << "----iterate on handle: " << target.get();
// If the operand uses an invalidated handle, report it. If the operation
// allows handles to point to repeated payload operations, only report
// pre-existing invalidation errors. Otherwise, also report invalidations
@@ -727,14 +717,14 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
auto it = invalidatedHandles.find(target.get());
auto nit = newlyInvalidated.find(target.get());
if (it != invalidatedHandles.end()) {
- FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
- "invalidated -> FAILURE\n");
+ FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already "
+ "invalidated -> FAILURE";
return it->getSecond()(transform->getLoc()), failure();
}
if (!transform.allowsRepeatedHandleOperands() &&
nit != newlyInvalidated.end()) {
- FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
- "invalidated (by this op) -> FAILURE\n");
+ FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly "
+ "invalidated (by this op) -> FAILURE";
return nit->getSecond()(transform->getLoc()), failure();
}
@@ -745,27 +735,28 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
effect.getValue() == target.get();
};
if (llvm::any_of(effects, consumesTarget)) {
- FULL_LDBG("----found consume effect\n");
+ FULL_LDBG() << "----found consume effect";
if (llvm::isa<transform::TransformHandleTypeInterface>(
target.get().getType())) {
- FULL_LDBG("----recordOpHandleInvalidation\n");
+ FULL_LDBG() << "----recordOpHandleInvalidation";
SmallVector<Operation *> payloadOps =
llvm::to_vector(getPayloadOps(target.get()));
recordOpHandleInvalidation(target, payloadOps, nullptr,
newlyInvalidated);
} else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
target.get().getType())) {
- FULL_LDBG("----recordValueHandleInvalidation\n");
+ FULL_LDBG() << "----recordValueHandleInvalidation";
recordValueHandleInvalidation(target, newlyInvalidated);
} else {
- FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
+ FULL_LDBG()
+ << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
}
} else {
- FULL_LDBG("----no consume effect -> SKIP\n");
+ FULL_LDBG() << "----no consume effect -> SKIP";
}
}
- FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
+ FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS";
return success();
}
@@ -818,18 +809,14 @@ void transform::TransformState::compactOpHandles() {
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
- LLVM_DEBUG({
- DBGS() << "applying: ";
- transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";
- });
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
- DBGS() << "Top-level payload before application:\n"
- << *getTopLevel() << "\n");
+ LDBG() << "applying: "
+ << OpWithFlags(transform, OpPrintingFlags().skipRegions());
+ FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
auto printOnFailureRAII = llvm::make_scope_exit([this] {
(void)this;
- LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
- llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
+ LDBG() << "Failing Top-level payload:\n"
+ << OpWithFlags(getTopLevel(),
+ OpPrintingFlags().printGenericOpForm());
});
// Set current transform op.
@@ -837,47 +824,45 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// Expensive checks to detect invalid transform IR.
if (options.getExpensiveChecksEnabled()) {
- FULL_LDBG("ExpensiveChecksEnabled\n");
+ FULL_LDBG() << "ExpensiveChecksEnabled";
if (failed(checkAndRecordHandleInvalidation(transform)))
return DiagnosedSilenceableFailure::definiteFailure();
for (OpOperand &operand : transform->getOpOperands()) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
- (DBGS() << "iterate on handle: " << operand.get() << "\n");
- });
+ FULL_LDBG() << "iterate on handle: " << operand.get();
if (!isHandleConsumed(operand.get(), transform)) {
- FULL_LDBG("--handle not consumed -> SKIP\n");
+ FULL_LDBG() << "--handle not consumed -> SKIP";
continue;
}
if (transform.allowsRepeatedHandleOperands()) {
- FULL_LDBG("--op allows repeated handles -> SKIP\n");
+ FULL_LDBG() << "--op allows repeated handles -> SKIP";
continue;
}
- FULL_LDBG("--handle is consumed\n");
+ FULL_LDBG() << "--handle is consumed";
Type operandType = operand.get().getType();
if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
- FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
+ FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*";
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Operation *>(
getPayloadOpsView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
- FULL_LDBG("----FAILED\n");
+ FULL_LDBG() << "----FAILED";
return check;
}
} else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
- FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
+ FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value";
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Value>(
getPayloadValuesView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
- FULL_LDBG("----FAILED\n");
+ FULL_LDBG() << "----FAILED";
return check;
}
} else {
- FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
+ FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR";
}
}
}
@@ -999,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
printOnFailureRAII.release();
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
- DBGS() << "Top-level payload:\n";
- getTopLevel()->print(llvm::dbgs());
+ LDBG() << "Top-level payload:\n" << *getTopLevel();
});
return result;
}
@@ -1277,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure(
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
- DBGS() << "Match Failure : " << diag.str() << "\n";
+ LDBG() << "Match Failure : " << diag.str();
});
}
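
The logging churn in this file follows the streaming helpers from llvm/Support/DebugLog.h: LDBG() logs under DEBUG_TYPE, LDBG(n) gates the message at verbosity level n, and the trailing newline is implicit. The old separate "transform-dialect-full" debug type collapses into level-4 messages:

    #include "llvm/Support/DebugLog.h"
    #define DEBUG_TYPE "transform-dialect"
    #define FULL_LDBG() LDBG(4) // verbose diagnostics, gated at level 4

    // Before: per-message macro plumbing and an explicit "\n".
    //   DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
    //                   { DBGS() << "--handle is consumed" << "\n"; });
    // After: stream and forget.
    //   FULL_LDBG() << "--handle is consumed";
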
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index d464230..0248896 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
+ MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRVectorInterfaces
)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed..86fbb76 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -372,9 +372,8 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
llvm::transform(foldResults, std::back_inserter(values),
[&](OpFoldResult foldResult) {
if (auto attr = dyn_cast<Attribute>(foldResult))
- return builder
- .create<arith::ConstantIndexOp>(
- loc, cast<IntegerAttr>(attr).getInt())
+ return arith::ConstantIndexOp::create(
+ builder, loc, cast<IntegerAttr>(attr).getInt())
.getResult();
return cast<Value>(foldResult);
@@ -1259,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
CanonicalizeContractAdd<arith::AddFOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
- Value source) {
- result.addOperands({source});
- result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
-}
-
-LogicalResult vector::ExtractElementOp::verify() {
- VectorType vectorType = getSourceVectorType();
- if (vectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (vectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here now.
- if (!adaptor.getPosition())
- return {};
-
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
-
- auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!pos || !src)
- return {};
-
- auto srcElements = src.getValues<Attribute>();
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= srcElements.size())
- return {};
-
- return srcElements[posIdx];
-}
-
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
@@ -2591,8 +2533,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
llvm::enumerate(fromElements.getElements())) {
// Check that the element is from a vector.extract operation.
- auto extractOp =
- dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+ auto extractOp = element.getDefiningOp<vector::ExtractOp>();
if (!extractOp) {
return rewriter.notifyMatchFailure(fromElements,
"element not from vector.extract");
@@ -3186,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// InsertElementOp
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
-}
-
-void InsertElementOp::build(OpBuilder &builder, OperationState &result,
- Value source, Value dest) {
- build(builder, result, source, dest, {});
-}
-
-LogicalResult InsertElementOp::verify() {
- auto dstVectorType = getDestVectorType();
- if (dstVectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (dstVectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here.
- if (!adaptor.getPosition())
- return {};
-
- auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
- auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!src || !dst || !pos)
- return {};
-
- if (src.getType() != getDestVectorType().getElementType())
- return {};
-
- auto dstElements = dst.getValues<Attribute>();
-
- SmallVector<Attribute> results(dstElements);
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= results.size())
- return {};
- results[posIdx] = src;
-
- return DenseElementsAttr::get(getDestVectorType(), results);
-}
-
-//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
@@ -6429,6 +6316,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}
+void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), argRanges.front());
+}
+
namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
@@ -7311,6 +7203,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
}
//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ auto resultType = cast<VectorType>(getType());
+ if (resultType.isScalable()) {
+ return;
+ }
+ unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
+ APInt zero(bitwidth, 0);
+ APInt high(bitwidth, resultType.getDimSize(0) - 1);
+ ConstantIntRanges result = {zero, high, zero, high};
+ setResultRanges(getResult(), result);
+}
+
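For intuition: on a fixed-length vector<N x index>, the inferred range is the closed interval [0, N-1]; scalable vectors are skipped above because the runtime length is unknown. A minimal standalone sketch of the same computation:

// Sketch of the range computed above for a fixed-length 1-D step result.
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/APInt.h"

static mlir::ConstantIntRanges stepResultRange(int64_t dimSize,
                                               unsigned bitwidth) {
  llvm::APInt zero(bitwidth, 0);
  llvm::APInt high(bitwidth, dimSize - 1);
  // Unsigned and signed bounds coincide: [0, dimSize - 1].
  return {zero, high, zero, high};
}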
+//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566..dedc3b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast with a vector operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,23 @@ public:
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
+ // A broadcast from a scalar is considered to be in the lowered form.
+ if (!srcType)
+ return rewriter.notifyMatchFailure(
+ op, "broadcast from scalar already in lowered form");
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+ // Here we are broadcasting to a rank-1 vector. Ensure that the source is a
+ // scalar.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ SmallVector<int64_t> fullRankPosition(srcRank, 0);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+ fullRankPosition);
+ assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
return success();
}
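Illustrative IR for the new rank-1 case (a sketch, not taken from the patch):

//   %s = vector.extract %src[0] : f32 from vector<1xf32>
//   %b = vector.broadcast %s : f32 to vector<4xf32>
//
// A broadcast whose source is already a scalar is now left for the target
// dialect (LLVM, SPIR-V) to consume directly instead of becoming a splat.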
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 2484670..e062f55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -248,11 +248,10 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
scf::YieldOp::create(b, loc, result);
};
- result =
- rewriter
- .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
+ result = scf::IfOp::create(rewriter, loc, condition,
+ /*thenBuilder=*/loadBuilder,
/*elseBuilder=*/passThruBuilder)
- .getResult(0);
+ .getResult(0);
}
rewriter.replaceOp(op, result);
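This hunk is one instance of a mechanical migration repeated throughout the patch: the rewriter-centric builder call is replaced by the op's static create entry point, which takes the builder as its first argument. Sketch:

// Before: auto ifOp = rewriter.create<scf::IfOp>(loc, cond, thenFn, elseFn);
// After:  auto ifOp = scf::IfOp::create(rewriter, loc, cond, thenFn, elseFn);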
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e910932..2cf8f0b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -142,8 +142,8 @@ struct TransferReadPermutationLowering
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- return rewriter
- .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+ return vector::TransposeOp::create(rewriter, op.getLoc(), newRead,
+ transposePerm)
.getResult();
}
};
@@ -371,8 +371,8 @@ struct TransferOpReduceRank
rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
- return rewriter
- .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+ return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
+ newRead)
.getVector();
}
};
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");
- Value fill = vector::SplatOp::create(
+ Value fill = vector::BroadcastOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
res = vector::MaskedLoadOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
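The splat-to-broadcast substitution here (and in the later hunks) is semantics-preserving; sketched IR for the padding fill:

//   %fill = vector.splat %pad : vector<16xf32>                 // old form
//   %fill = vector.broadcast %pad : f32 to vector<16xf32>      // new form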
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 58e94ea..bb0f339 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -451,10 +451,9 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
}
SmallVector<Value> delinearized;
if (map.getNumResults() > 1) {
- delinearized = rewriter
- .create<mlir::affine::AffineDelinearizeIndexOp>(
- newWarpOp.getLoc(), newWarpOp.getLaneid(),
- delinearizedIdSizes)
+ delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
+ rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
+ delinearizedIdSizes)
.getResults();
} else {
// If there is only one map result, we can elide the delinearization
@@ -1538,19 +1537,18 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
newWarpOp.getLaneid(), insertingLane);
Value newResult =
- rewriter
- .create<scf::IfOp>(
- loc, isInsertingLane,
- /*thenBuilder=*/
- [&](OpBuilder &builder, Location loc) {
- Value newInsert = vector::InsertOp::create(
- builder, loc, newSource, distributedVec, newPos);
- scf::YieldOp::create(builder, loc, newInsert);
- },
- /*elseBuilder=*/
- [&](OpBuilder &builder, Location loc) {
- scf::YieldOp::create(builder, loc, distributedVec);
- })
+ scf::IfOp::create(
+ rewriter, loc, isInsertingLane,
+ /*thenBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ Value newInsert = vector::InsertOp::create(
+ builder, loc, newSource, distributedVec, newPos);
+ scf::YieldOp::create(builder, loc, newInsert);
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ scf::YieldOp::create(builder, loc, distributedVec);
+ })
.getResult(0);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
return success();
@@ -1661,10 +1659,9 @@ struct WarpOpInsert : public WarpDistributionPattern {
auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
scf::YieldOp::create(builder, loc, distributedDest);
};
- newResult = rewriter
- .create<scf::IfOp>(loc, isInsertingLane,
- /*thenBuilder=*/insertingBuilder,
- /*elseBuilder=*/nonInsertingBuilder)
+ newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
+ /*thenBuilder=*/insertingBuilder,
+ /*elseBuilder=*/nonInsertingBuilder)
.getResult(0);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 73388a5..9889d7f2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -466,9 +466,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
}
- return rewriter
- .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
- newOp->getResults()[0])
+ return vector::BroadcastOp::create(rewriter, loc,
+ contractOp->getResultTypes()[0],
+ newOp->getResults()[0])
.getResult();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e6bb96f..f78e579 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -32,7 +32,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -41,9 +41,6 @@
using namespace mlir;
#define DEBUG_TYPE "vector-narrow-type-emulation"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using VectorValue = TypedValue<VectorType>;
using MemRefValue = TypedValue<MemRefType>;
@@ -135,17 +132,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
newMaskOperands);
})
- .Case<vector::ConstantMaskOp>(
- [&](auto constantMaskOp) -> std::optional<Operation *> {
- // Take the shape of mask, compress its trailing dimension:
- SmallVector<int64_t> maskDimSizes(
- constantMaskOp.getMaskDimSizes());
- int64_t &maskIndex = maskDimSizes.back();
- maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
- numSrcElemsPerDest);
- return vector::ConstantMaskOp::create(
- rewriter, loc, newMaskType, maskDimSizes);
- })
+ .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+ -> std::optional<Operation *> {
+ // Take the shape of the mask and compress its trailing dimension:
+ SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
+ int64_t &maskIndex = maskDimSizes.back();
+ maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+ numSrcElemsPerDest);
+ return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
+ maskDimSizes);
+ })
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
// TODO: Support multiple dimensions.
@@ -232,9 +228,8 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto resultVectorType =
VectorType::get({numElemsToExtract}, vectorType.getElementType());
- return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
- offsets, sizes, strides)
+ return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
+ src, offsets, sizes, strides)
->getResult(0);
}
@@ -1526,11 +1521,11 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
"requires -D non-scalable vector type");
int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
- LDBG("sourceVectorType: " << sourceVectorType);
+ LDBG() << "sourceVectorType: " << sourceVectorType;
int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
int64_t mostMinorTargetDim = targetVectorType.getShape().back();
- LDBG("targetVectorType: " << targetVectorType);
+ LDBG() << "targetVectorType: " << targetVectorType;
int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
(void)mostMinorSourceDim;
@@ -1555,7 +1550,7 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
VectorType targetVectorType)
: enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
- LDBG("\n" << enumerator.sourceElementRanges);
+ LDBG() << "\n" << enumerator.sourceElementRanges;
}
/// Verify that the precondition type meets the common preconditions for any
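The logging migration removes the per-file DBGS/LDBG macro boilerplate in favor of the streaming helper from DebugLog.h, which prepends the [DEBUG_TYPE] prefix and appends the newline itself. A minimal sketch, assuming a translation unit that defines DEBUG_TYPE:

#define DEBUG_TYPE "my-pass"
#include "llvm/Support/DebugLog.h"

static void logSize(int64_t n) {
  // Old: LLVM_DEBUG(DBGS() << "size = " << n << "\n");
  LDBG() << "size = " << n; // prefix and newline are implicit
}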
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d7..cbb9d4b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ public:
// Extract/insert on a lower ranked extract strided slice op.
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, dstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2676d25..c707f38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -25,12 +25,10 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-transfer-opt"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
using namespace mlir;
/// Return the ancestor op in the region or nullptr if the region is not
@@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
- LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
- << "\n");
+ LDBG() << "Candidate for dead store: " << *write.getOperation();
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
@@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
!isReachable(writeAncestor, accessAncestor))
continue;
if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
- LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
- << *accessAncestor << "\n");
+ LDBG() << "Store may not be dead due to op: " << *accessAncestor;
return;
}
}
- LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
- << " overwritten by: " << *firstOverwriteCandidate << "\n");
+ LDBG() << "Found dead store: " << *write.getOperation()
+ << " overwritten by: " << *firstOverwriteCandidate;
opToErase.push_back(write.getOperation());
}
@@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (read.hasOutOfBoundsDim())
return;
- LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
- << "\n");
+ LDBG() << "Candidate for Forwarding: " << *read.getOperation();
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
@@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
continue;
if (!postDominators.postDominates(lastwrite, write)) {
- LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
- << *write << "\n");
+ LDBG() << "Fail to do write to read forwarding due to op: " << *write;
return;
}
}
- LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
- << " to: " << *read.getOperation() << "\n");
+ LDBG() << "Forward value from " << *lastwrite.getOperation()
+ << " to: " << *read.getOperation();
read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
@@ -330,8 +324,8 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
}
reducedOperands.push_back(operand);
}
- return rewriter
- .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
+ return vector::CreateMaskOp::create(rewriter, loc, reducedType,
+ reducedOperands)
.getResult();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 05b0074..5e12dc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -348,24 +348,23 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
Location loc = xferOp.getLoc();
Value zero = arith::ConstantIndexOp::create(b, loc, 0);
Value memref = xferOp.getBase();
- return b
- .create<scf::IfOp>(
- loc, inBoundsCond,
- [&](OpBuilder &b, Location loc) {
- Value res =
- castToCompatibleMemRefType(b, memref, compatibleMemRefType);
- scf::ValueVector viewAndIndices{res};
- llvm::append_range(viewAndIndices, xferOp.getIndices());
- scf::YieldOp::create(b, loc, viewAndIndices);
- },
- [&](OpBuilder &b, Location loc) {
- Value casted =
- castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
- scf::ValueVector viewAndIndices{casted};
- viewAndIndices.insert(viewAndIndices.end(),
- xferOp.getTransferRank(), zero);
- scf::YieldOp::create(b, loc, viewAndIndices);
- })
+ return scf::IfOp::create(
+ b, loc, inBoundsCond,
+ [&](OpBuilder &b, Location loc) {
+ Value res =
+ castToCompatibleMemRefType(b, memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ llvm::append_range(viewAndIndices, xferOp.getIndices());
+ scf::YieldOp::create(b, loc, viewAndIndices);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value casted =
+ castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(),
+ xferOp.getTransferRank(), zero);
+ scf::YieldOp::create(b, loc, viewAndIndices);
+ })
->getResults();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327..2269a40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -410,9 +410,8 @@ FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
VectorType maskOpType =
VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
- mask = rewriter
- .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
- maskingOp.getMask())
+ mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
+ maskOpType, maskingOp.getMask())
.getResult();
}
@@ -940,7 +939,7 @@ public:
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
@@ -966,6 +965,45 @@ private:
std::function<bool(BitCastOp)> controlFn;
};
+static bool haveSameShapeAndScaling(Type t, Type u) {
+ auto tVec = dyn_cast<VectorType>(t);
+ auto uVec = dyn_cast<VectorType>(u);
+ if (!tVec) {
+ return !uVec;
+ }
+ if (!uVec) {
+ return false;
+ }
+ return tVec.getShape() == uVec.getShape() &&
+ tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
+ return shapedType.clone(newElementType);
+ }
+ return newElementType;
+}
+
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+
+ Operation *op = value.getDefiningOp();
+ if (!op)
+ return {};
+
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcast.getSource();
+
+ if (auto splat = dyn_cast<vector::SplatOp>(op))
+ return splat.getInput();
+
+ return {};
+}
+
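Sketched behavior of the helper just added (illustrative IR in comments):

//   %v = vector.broadcast %s : f32 to vector<4xf32>  -> returns %s
//   %v = vector.splat %s : vector<4xf32>             -> returns %s
//   %v = arith.addf %a, %b : vector<4xf32>           -> returns Value()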
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
@@ -989,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
- if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType)
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
- if (op->getResults()[0].getType() != op->getOperand(0).getType())
- return rewriter.notifyMatchFailure(op,
- "result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
@@ -1006,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
- // Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+ Type resultElemType = resultType.getElementType();
+
+ // Get the type of the first non-constant operand
+ Value splatSource;
+ for (Value operand : op->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (!definingOp)
+ return failure();
+ if (definingOp->hasTrait<OpTrait::ConstantLike>())
+ continue;
+ splatSource = getBroadcastLikeSource(operand);
+ break;
+ }
+ if (!splatSource)
return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Type unbroadcastResultType =
+ cloneOrReplace(splatSource.getType(), resultElemType);
- // Make sure that all operands are broadcast from identical types:
+ // Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
- return false;
+ if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+ if (auto source = getBroadcastLikeSource(val))
+ return haveSameShapeAndScaling(source.getType(),
+ splatSource.getType());
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
})) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op,
+ "not all operands are constants or broadcasts from the same type");
}
// Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ SplatElementsAttr splatConst;
+ if (matchPattern(operand, m_Constant(&splatConst))) {
+ Attribute newConst;
+ Type elementType = getElementTypeOrSelf(operand.getType());
+ Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+ if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
+ newConst = splatConst.resizeSplat(newTypeShaped);
+ } else {
+ newConst = splatConst.getSplatValue<Attribute>();
+ }
+ Operation *newConstOp =
+ operand.getDefiningOp()->getDialect()->materializeConstant(
+ rewriter, newConst, newType, operand.getLoc());
+ srcValues.push_back(newConstOp->getResult(0));
+ } else {
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ }
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- lhsBcastOrSplatType, op->getAttrs());
+ unbroadcastResultType, op->getAttrs());
// Replace the original Op with the elementwise Op
- auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- op, vectorType, elementwiseOp->getResults());
+ op, resultType, elementwiseOp->getResults());
return success();
}
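With the constant-splat handling above, the pattern now also rewrites cases like the following (a sketch; the splat constant is resized to the unbroadcast type before the elementwise op is rebuilt):

// Before:
//   %b = vector.broadcast %x : f32 to vector<8xf32>
//   %c = arith.constant dense<2.0> : vector<8xf32>
//   %r = arith.mulf %b, %c : vector<8xf32>
// After:
//   %c = arith.constant 2.0 : f32
//   %m = arith.mulf %x, %c : f32
//   %r = vector.broadcast %m : f32 to vector<8xf32>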
@@ -1240,15 +1302,17 @@ public:
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");
- Operation *splat = op.getValueToStore().getDefiningOp();
- if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
- return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+ Value toStore = op.getValueToStore();
+ Value source = getBroadcastLikeSource(toStore);
+ if (!source)
+ return rewriter.notifyMatchFailure(
+ op, "value to store is not from a broadcast");
// Checking for single use so we can remove splat.
+ Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
- Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();
@@ -1298,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+ Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
- vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+ vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
indices, bounds);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fceba65..501abec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -16,13 +16,11 @@
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
#define DEBUG_TYPE "vector-unroll"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
using namespace mlir::vector;
@@ -90,10 +88,9 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
- LDBG("");
- LDBG("Get unroll shape for op " << op->getName().getStringRef());
+ LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
if (options.filterConstraint && failed(options.filterConstraint(op))) {
- LDBG("--no filter constraint -> BAIL");
+ LDBG() << "--no filter constraint -> BAIL";
return std::nullopt;
}
assert(options.nativeShape &&
@@ -101,33 +98,33 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
"shape call back function to be set");
auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
if (!unrollableVectorOp) {
- LDBG("--not an unrollable op -> BAIL");
+ LDBG() << "--not an unrollable op -> BAIL";
return std::nullopt;
}
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
if (!maybeUnrollShape) {
- LDBG("--could not get shape of op " << *op << " -> BAIL");
+ LDBG() << "--could not get shape of op " << *op << " -> BAIL";
return std::nullopt;
}
- LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));
+ LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
if (!targetShape) {
- LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
+ LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
return std::nullopt;
}
- LDBG("--target shape: " << llvm::interleaved(*targetShape));
+ LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
if (!maybeShapeRatio) {
- LDBG("--could not compute integral shape ratio -> BAIL");
+ LDBG() << "--could not compute integral shape ratio -> BAIL";
return std::nullopt;
}
if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
- LDBG("--no unrolling needed -> SKIP");
+ LDBG() << "--no unrolling needed -> SKIP";
return std::nullopt;
}
- LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
+ LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
return targetShape;
}
@@ -169,7 +166,7 @@ struct UnrollTransferReadPattern
auto sourceVectorType = readOp.getVectorType();
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = readOp.getLoc();
- ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+ ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
// Prepare the result vector.
Value result =
@@ -225,6 +222,14 @@ struct UnrollTransferWritePattern
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = writeOp.getLoc();
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+ // Bail out if rank(source) != rank(target). The main limitation is that
+ // `ExtractStridedSlice` requires the input and output ranks to match; if
+ // needed, this can be relaxed later.
+ if (originalSize.size() != targetShape->size())
+ return rewriter.notifyMatchFailure(
+ writeOp,
+ "expected source input vector rank to match target shape rank");
+
SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<int64_t> loopOrder =
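The new guard rejects, for example, unrolling a rank-1 write with a rank-2 target shape, since vector.extract_strided_slice requires matching input/output ranks; sketch:

//   source vector<16xf32>, targetShape = {2, 8} -> notifyMatchFailure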
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c045063..10ed2bc 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -27,13 +27,11 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "vector-utils"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-
using namespace mlir;
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
@@ -369,14 +367,14 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
LogicalResult
vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
ArrayRef<int64_t> inputVectorSizes) {
- LDBG("Iteration space static sizes:" << llvm::interleaved(shape));
+ LDBG() << "Iteration space static sizes:" << llvm::interleaved(shape);
if (inputVectorSizes.size() != shape.size()) {
- LDBG("Input vector sizes don't match the number of loops");
+ LDBG() << "Input vector sizes don't match the number of loops";
return failure();
}
if (ShapedType::isDynamicShape(inputVectorSizes)) {
- LDBG("Input vector sizes can't have dynamic dimensions");
+ LDBG() << "Input vector sizes can't have dynamic dimensions";
return failure();
}
if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
@@ -386,8 +384,9 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return ShapedType::isDynamic(staticSize) ||
staticSize <= inputSize;
})) {
- LDBG("Input vector sizes must be greater than or equal to iteration space "
- "static sizes");
+ LDBG() << "Input vector sizes must be greater than or equal to iteration "
+ "space "
+ "static sizes";
return failure();
}
return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deea..33450f3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
+static LogicalResult
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
+ int64_t chunkSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!valueTy)
+ return emitError() << "Expecting a vector type result.";
+
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+
+ // A valid shape for the SIMT case.
+ if (valueTy.getRank() == 1) {
+ if (valueTy.getNumElements() != chunkSize)
+ return emitError() << "value elements must match chunk size " << chunkSize
+ << " for SIMT code.";
+ return success();
+ }
+
+ llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
+ return emitError() << "Mask should match value except the chunk size dim.";
+
+ return success();
+}
+
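Concretely, the new helper expects the mask shape to equal the value shape minus the trailing chunk dimension; a couple of sketched cases:

//   value vector<16x8xf16>, chunkSize = 8 -> expected mask shape [16]
//   value vector<16x8xf16>, chunkSize = 1 -> expected mask shape [16, 8]
//   value vector<8xf16> (rank 1, SIMT)    -> requires 8 == chunkSize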
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (!tdescTy.isScattered())
+
+ if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
+
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+ auto srcTy = getSourceType();
+ int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(srcTy);
+
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+ l1_hint, l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
@@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!tdescTy && getRankOf(getDest()) > 1)
+ return emitOpError(
+ "Expecting the dest is a 1D memref or pointer (uint64_t).");
+
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+
+ auto destTy = getDestType();
+ int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(destTy);
+
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+ l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 4656f11..d82c541 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -17,6 +17,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
namespace xegpu {
@@ -26,8 +27,6 @@ namespace xegpu {
} // namespace mlir
#define DEBUG_TYPE "xegpu-blocking"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
@@ -53,7 +52,7 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
// We are only interested in the case where all inputs and outputs have
// identical VectorTypes.
if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
- LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
+ LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
return;
}
@@ -149,7 +148,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
}
- LDBG("failed to getTileShape for: " << value);
+ LDBG() << "failed to getTileShape for: " << value;
return std::nullopt;
}
@@ -214,7 +213,7 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
return layout && layout.isWgLayout();
});
if (hasWgLayoutOperands || hasWgLayoutResults) {
- LDBG("skip unrolling for op with workgroup level layout: " << *op);
+ LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
return false;
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a6208b4..c793b71 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -17,7 +17,7 @@
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
namespace xegpu {
@@ -27,8 +27,6 @@ namespace xegpu {
} // namespace mlir
#define DEBUG_TYPE "xegpu-unroll"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
@@ -44,11 +42,10 @@ protected:
/// Return the target shape for the given `op`. Return std::nullopt if the
/// op shouldn't be or cannot be unrolled.
std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
- LDBG("");
- LDBG("Get unroll shape for: " << *op);
+ LDBG() << "Get unroll shape for: " << *op;
if (options.filterConstraint && failed(options.filterConstraint(op))) {
- LDBG("--no filter constraint -> BAIL");
+ LDBG() << "--no filter constraint -> BAIL";
return std::nullopt;
}
@@ -484,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ // TODO: handle the unstructured source case (!tdescTy).
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -546,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ // TODO: handle the unstructured source case (!tdescTy).
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -575,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ // TODO: handle the unstructured source case (!tdescTy).
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 229a289..850f70c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -207,7 +207,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
// Subtract startOfRange from the original subgroup id to get the adjusted
// sg id
Value startOfRangeVal =
- rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
adjustedSgId =
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
}
@@ -431,8 +431,8 @@ struct WgToSgVectorBroadcastOp
SmallVector<Value> newBroadcastOps;
for (auto operand : adaptor.getOperands().front()) {
- auto newBroadcast = rewriter.create<vector::BroadcastOp>(
- op.getLoc(), newResultType, operand);
+ auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
+ newResultType, operand);
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
@@ -563,8 +563,8 @@ struct WgToSgConvertLayoutOp
if (input && target) {
// Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
- auto newOp = rewriter.create<xegpu::ConvertLayoutOp>(
- op.getLoc(), src.getType(), src, input, target);
+ auto newOp = xegpu::ConvertLayoutOp::create(
+ rewriter, op.getLoc(), src.getType(), src, input, target);
newOps[i] = newOp;
}
}
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 0652202..e55a666 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -8,7 +8,6 @@
#include <cmath>
#include <cstdint>
-#include <limits>
#include <utility>
#include "AffineExprDetail.h"
@@ -16,7 +15,6 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad29..de52fbd 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -40,7 +40,7 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
@@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
return failure();
});
if (failed(verify(op))) {
- LLVM_DEBUG(llvm::dbgs()
- << DEBUG_TYPE << ": '" << op->getName()
- << "' failed to verify and will be printed in generic form\n");
+ LDBG() << "'" << op->getName()
+ << "' failed to verify and will be printed in generic form";
printerFlags.printGenericOpForm();
}
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 3e33795..776b5c6 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
(void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
- // Register a handler to verify the diagnostics.
- setHandler([&](Diagnostic &diag) {
- // Process the main diagnostics.
- process(diag);
-
- // Process each of the notes.
- for (auto &note : diag.getNotes())
- process(note);
- });
+ registerInContext(ctx);
}
SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
@@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
return impl->status;
}
+void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) {
+ ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) {
+ // Process the main diagnostics.
+ process(diag);
+
+ // Process each of the notes.
+ for (auto &note : diag.getNotes())
+ process(note);
+ });
+}
+
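Factoring the registration out lets a verifier handler be attached to a context other than the one it was built with; a usage sketch (the handler type and method are from the patch, the variable names are illustrative):

// SourceMgrDiagnosticVerifierHandler handler(srcMgr, &primaryCtx);
// handler.registerInContext(&otherCtx); // also verify diagnostics from otherCtx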
/// Process a single diagnostic.
void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
return process(diag.getLocation(), diag.str(), diag.getSeverity());
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index f897546..23e70c6 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -18,13 +18,9 @@
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/TrailingObjects.h"
#include <cassert>
-#include <iterator>
-#include <memory>
-#include <optional>
#include <tuple>
#include <utility>
diff --git a/mlir/lib/IR/PDL/PDLPatternMatch.cpp b/mlir/lib/IR/PDL/PDLPatternMatch.cpp
index 28b39dd..62a71aa 100644
--- a/mlir/lib/IR/PDL/PDLPatternMatch.cpp
+++ b/mlir/lib/IR/PDL/PDLPatternMatch.cpp
@@ -7,10 +7,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Iterators.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/RegionKindInterface.h"
#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
diff --git a/mlir/lib/IR/PatternLoggingListener.cpp b/mlir/lib/IR/PatternLoggingListener.cpp
index ce2123a..0db13ab 100644
--- a/mlir/lib/IR/PatternLoggingListener.cpp
+++ b/mlir/lib/IR/PatternLoggingListener.cpp
@@ -1,50 +1,48 @@
#include "mlir/IR/PatternMatch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "pattern-logging-listener"
-#define DBGS() (llvm::dbgs() << "[" << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
void RewriterBase::PatternLoggingListener::notifyOperationInserted(
Operation *op, InsertPoint previous) {
- LDBG(patternName << " | notifyOperationInserted"
- << " | " << op->getName());
+ LDBG() << patternName << " | notifyOperationInserted"
+ << " | " << op->getName();
ForwardingListener::notifyOperationInserted(op, previous);
}
void RewriterBase::PatternLoggingListener::notifyOperationModified(
Operation *op) {
- LDBG(patternName << " | notifyOperationModified"
- << " | " << op->getName());
+ LDBG() << patternName << " | notifyOperationModified"
+ << " | " << op->getName();
ForwardingListener::notifyOperationModified(op);
}
void RewriterBase::PatternLoggingListener::notifyOperationReplaced(
Operation *op, Operation *newOp) {
- LDBG(patternName << " | notifyOperationReplaced (with op)"
- << " | " << op->getName() << " | " << newOp->getName());
+ LDBG() << patternName << " | notifyOperationReplaced (with op)"
+ << " | " << op->getName() << " | " << newOp->getName();
ForwardingListener::notifyOperationReplaced(op, newOp);
}
void RewriterBase::PatternLoggingListener::notifyOperationReplaced(
Operation *op, ValueRange replacement) {
- LDBG(patternName << " | notifyOperationReplaced (with values)"
- << " | " << op->getName());
+ LDBG() << patternName << " | notifyOperationReplaced (with values)"
+ << " | " << op->getName();
ForwardingListener::notifyOperationReplaced(op, replacement);
}
void RewriterBase::PatternLoggingListener::notifyOperationErased(
Operation *op) {
- LDBG(patternName << " | notifyOperationErased"
- << " | " << op->getName());
+ LDBG() << patternName << " | notifyOperationErased"
+ << " | " << op->getName();
ForwardingListener::notifyOperationErased(op);
}
void RewriterBase::PatternLoggingListener::notifyPatternBegin(
const Pattern &pattern, Operation *op) {
- LDBG(patternName << " | notifyPatternBegin"
- << " | " << op->getName());
+ LDBG() << patternName << " | notifyPatternBegin"
+ << " | " << op->getName();
ForwardingListener::notifyPatternBegin(pattern, op);
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 1e60848..9332f55 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -7,8 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Config/mlir-config.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -158,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
// Fast path: If no listener is attached, the op can be dropped in one go.
if (!rewriteListener) {
op->erase();
@@ -322,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
assert(source->empty() && "expected 'source' to be empty");
eraseBlock(source);
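Both guards keep the rewriter's insertion point from dangling across IR mutations; the failure mode they prevent, sketched:

// rewriter.setInsertionPoint(op);  // insertion point == op->getIterator()
// rewriter.eraseOp(op);            // point is first moved to just after `op`
// OtherOp::create(rewriter, loc);  // safe: no dangling block iterator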
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 07c311b..87b4799 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -10,7 +10,6 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSwitch.h"
#include <optional>
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 7b3a946..fa550e4 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -8,9 +8,7 @@
#include "mlir/IR/Value.h"
#include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
-#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
using namespace mlir::detail;
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e9b5e92..310680b 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -17,14 +17,32 @@
using namespace mlir;
+static std::pair<int64_t, int64_t>
+getLineAndColStart(const llvm::SourceMgr &sourceMgr) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ if (lastFileID == 1)
+ return {0, 0};
+
+ auto bufferID = sourceMgr.getMainFileID();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ // A strict '<' excludes a buffer that shares the main buffer's start.
+ if (main->getBufferStart() < last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ return sourceMgr.getLineAndColumn(
+ llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID);
+ }
+ return {0, 0};
+}
+
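This recovers a meaningful start location when the main buffer embeds further buffers (e.g. the chunks produced by -split-input-file); sketched behavior:

//   single buffer                         -> {0, 0}
//   last buffer contained in main buffer  -> line/column of its start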
LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
Block *block, const ParserConfig &config,
LocationAttr *sourceFileLoc) {
const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
if (sourceFileLoc) {
- *sourceFileLoc = FileLineColLoc::get(config.getContext(),
- sourceBuf->getBufferIdentifier(),
- /*line=*/0, /*column=*/0);
+ auto [line, column] = getLineAndColStart(sourceMgr);
+ *sourceFileLoc = FileLineColLoc::get(
+ config.getContext(), sourceBuf->getBufferIdentifier(), line, column);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(*sourceBuf, block, config);
@@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
const auto *sourceBuf =
sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
if (sourceFileLoc) {
- *sourceFileLoc = FileLineColLoc::get(config.getContext(),
- sourceBuf->getBufferIdentifier(),
- /*line=*/0, /*column=*/0);
+ auto [line, column] = getLineAndColStart(*sourceMgr);
+ *sourceFileLoc = FileLineColLoc::get(
+ config.getContext(), sourceBuf->getBufferIdentifier(), line, column);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(sourceMgr, block, config);
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0db9808..7094c8e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) {
if (failed(initialize(context, impl->initializationGeneration + 1)))
return failure();
initializationKey = newInitKey;
- pipelineKey = pipelineInitializationKey;
+ pipelineInitializationKey = pipelineKey;
}
// Construct a top level analysis manager for the pipeline.
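The one-line change stores the freshly computed pipeline key back into the cached pipelineInitializationKey (the old code overwrote the fresh key with the stale cached value); the intended caching pattern, sketched:

// if (newInitKey != initializationKey ||
//     pipelineKey != pipelineInitializationKey) {
//   if (failed(initialize(context, impl->initializationGeneration + 1)))
//     return failure();
//   initializationKey = newInitKey;          // cache what we initialized with
//   pipelineInitializationKey = pipelineKey; // ditto for the pipeline key
// }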
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 7c294f0..bc766d4 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -10,7 +10,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Format.h"
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
index 01f412a..21524f0 100644
--- a/mlir/lib/Query/Matcher/MatchersInternal.cpp
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Query/Matcher/MatchersInternal.h"
-#include "llvm/ADT/SetVector.h"
namespace mlir::query::matcher {
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
new file mode 100644
index 0000000..7a345ed
--- /dev/null
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -0,0 +1,207 @@
+//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialects and
+// passes to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllDialects.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Target/LLVM/NVVM/Target.h"
+#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
+
+/// Add all the MLIR dialects to the provided registry.
+void mlir::registerAllDialects(DialectRegistry &registry) {
+ // clang-format off
+ registry.insert<acc::OpenACCDialect,
+ affine::AffineDialect,
+ amdgpu::AMDGPUDialect,
+ amx::AMXDialect,
+ arith::ArithDialect,
+ arm_neon::ArmNeonDialect,
+ arm_sme::ArmSMEDialect,
+ arm_sve::ArmSVEDialect,
+ async::AsyncDialect,
+ bufferization::BufferizationDialect,
+ cf::ControlFlowDialect,
+ complex::ComplexDialect,
+ DLTIDialect,
+ emitc::EmitCDialect,
+ func::FuncDialect,
+ gpu::GPUDialect,
+ index::IndexDialect,
+ irdl::IRDLDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ math::MathDialect,
+ memref::MemRefDialect,
+ shard::ShardDialect,
+ ml_program::MLProgramDialect,
+ mpi::MPIDialect,
+ nvgpu::NVGPUDialect,
+ NVVM::NVVMDialect,
+ omp::OpenMPDialect,
+ pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
+ ptr::PtrDialect,
+ quant::QuantDialect,
+ ROCDL::ROCDLDialect,
+ scf::SCFDialect,
+ shape::ShapeDialect,
+ smt::SMTDialect,
+ sparse_tensor::SparseTensorDialect,
+ spirv::SPIRVDialect,
+ tensor::TensorDialect,
+ tosa::TosaDialect,
+ transform::TransformDialect,
+ ub::UBDialect,
+ vector::VectorDialect,
+ x86vector::X86VectorDialect,
+ xegpu::XeGPUDialect,
+ xevm::XeVMDialect>();
+ // clang-format on
+
+ // Register all external models.
+ affine::registerValueBoundsOpInterfaceExternalModels(registry);
+ arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
+ arith::registerValueBoundsOpInterfaceExternalModels(registry);
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ builtin::registerCastOpInterfaceExternalModels(registry);
+ cf::registerBufferizableOpInterfaceExternalModels(registry);
+ cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerValueBoundsOpInterfaceExternalModels(registry);
+ LLVM::registerInlinerInterface(registry);
+ NVVM::registerInlinerInterface(registry);
+ linalg::registerAllDialectInterfaceImplementations(registry);
+ linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerValueBoundsOpInterfaceExternalModels(registry);
+ memref::registerMemorySlotExternalModels(registry);
+ ml_program::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ scf::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerValueBoundsOpInterfaceExternalModels(registry);
+ shape::registerBufferizableOpInterfaceExternalModels(registry);
+ sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
+ tensor::registerInferTypeOpInterfaceExternalModels(registry);
+ tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ tensor::registerSubsetOpInterfaceExternalModels(registry);
+ tensor::registerTilingInterfaceExternalModels(registry);
+ tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
+ vector::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerSubsetOpInterfaceExternalModels(registry);
+ vector::registerValueBoundsOpInterfaceExternalModels(registry);
+ NVVM::registerNVVMTargetInterfaceExternalModels(registry);
+ ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+ spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+}
+
+/// Append all the MLIR dialects to the registry contained in the given context.
+void mlir::registerAllDialects(MLIRContext &context) {
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ context.appendDialectRegistry(registry);
+}
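For orientation, a minimal usage sketch (not part of the patch) of how a tool wires this helper into its startup; the explicit func dialect load is illustrative only, since dialects otherwise load lazily on first use:

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/MLIRContext.h"
  #include "mlir/InitAllDialects.h"

  int main() {
    mlir::DialectRegistry registry;
    mlir::registerAllDialects(registry);  // make every in-tree dialect available
    mlir::MLIRContext context(registry);  // loading stays lazy
    context.loadDialect<mlir::func::FuncDialect>(); // force-load one dialect
    return 0;
  }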
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
new file mode 100644
index 0000000..8f7c67c
--- /dev/null
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -0,0 +1,115 @@
+//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialect
+// extensions to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllExtensions.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+
+/// This function may be called to register all MLIR dialect extensions with the
+/// provided registry.
+/// If you're building a compiler, you generally shouldn't use this: you would
+/// individually register the specific extensions that are useful for the
+/// pipelines and transformations you are using.
+void mlir::registerAllExtensions(DialectRegistry &registry) {
+ // Register all conversions to LLVM extensions.
+ registerConvertArithToEmitCInterface(registry);
+ arith::registerConvertArithToLLVMInterface(registry);
+ registerConvertComplexToLLVMInterface(registry);
+ cf::registerConvertControlFlowToLLVMInterface(registry);
+ func::registerAllExtensions(registry);
+ tensor::registerAllExtensions(registry);
+ registerConvertFuncToEmitCInterface(registry);
+ registerConvertFuncToLLVMInterface(registry);
+ index::registerConvertIndexToLLVMInterface(registry);
+ registerConvertMathToLLVMInterface(registry);
+ mpi::registerConvertMPIToLLVMInterface(registry);
+ registerConvertMemRefToEmitCInterface(registry);
+ registerConvertMemRefToLLVMInterface(registry);
+ registerConvertNVVMToLLVMInterface(registry);
+ registerConvertOpenMPToLLVMInterface(registry);
+ registerConvertSCFToEmitCInterface(registry);
+ ub::registerConvertUBToLLVMInterface(registry);
+ registerConvertAMXToLLVMInterface(registry);
+ gpu::registerConvertGpuToLLVMInterface(registry);
+ NVVM::registerConvertGpuToNVVMInterface(registry);
+ vector::registerConvertVectorToLLVMInterface(registry);
+ registerConvertXeVMToLLVMInterface(registry);
+
+ // Register all transform dialect extensions.
+ affine::registerTransformDialectExtension(registry);
+ bufferization::registerTransformDialectExtension(registry);
+ dlti::registerTransformDialectExtension(registry);
+ func::registerTransformDialectExtension(registry);
+ gpu::registerTransformDialectExtension(registry);
+ linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
+ nvgpu::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
+ sparse_tensor::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
+ transform::registerDebugExtension(registry);
+ transform::registerIRDLExtension(registry);
+ transform::registerLoopExtension(registry);
+ transform::registerPDLExtension(registry);
+ transform::registerTuneExtension(registry);
+ vector::registerTransformDialectExtension(registry);
+ arm_neon::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
+
+ // Translation extensions need to be registered by calling
+ // `registerAllToLLVMIRTranslations` (see All.h).
+}
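As the header comment advises, a compiler would register only the extensions its pipelines need. A hedged sketch reusing two of the interfaces registered above (the wrapper function name is hypothetical):

  #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
  #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
  #include "mlir/IR/DialectRegistry.h"

  static void registerMinimalLoweringExtensions(mlir::DialectRegistry &registry) {
    // Just enough for a small arith+func-to-LLVM lowering pipeline.
    mlir::arith::registerConvertArithToLLVMInterface(registry);
    mlir::registerConvertFuncToLLVMInterface(registry);
  }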
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
new file mode 100644
index 0000000..1ed3a37
--- /dev/null
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -0,0 +1,99 @@
+//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all passes to the
+// system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllPasses.h"
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Transforms/Passes.h"
+
+// This function may be called to register the MLIR passes with the
+// global registry.
+// If you're building a compiler, you likely don't need this: you would build a
+// pipeline programmatically without the need to register with the global
+// registry, since it would already be calling the creation routine of the
+// individual passes.
+// The global registry is mainly useful for interacting with the command-line
+// tools.
+void mlir::registerAllPasses() {
+ // General passes
+ registerTransformsPasses();
+
+ // Conversion passes
+ registerConversionPasses();
+
+ // Dialect passes
+ acc::registerOpenACCPasses();
+ affine::registerAffinePasses();
+ amdgpu::registerAMDGPUPasses();
+ registerAsyncPasses();
+ arith::registerArithPasses();
+ bufferization::registerBufferizationPasses();
+ func::registerFuncPasses();
+ registerGPUPasses();
+ registerLinalgPasses();
+ registerNVGPUPasses();
+ registerSparseTensorPasses();
+ LLVM::registerLLVMPasses();
+ math::registerMathPasses();
+ memref::registerMemRefPasses();
+ shard::registerShardPasses();
+ ml_program::registerMLProgramPasses();
+ quant::registerQuantPasses();
+ registerSCFPasses();
+ registerShapePasses();
+ spirv::registerSPIRVPasses();
+ tensor::registerTensorPasses();
+ tosa::registerTosaOptPasses();
+ transform::registerTransformPasses();
+ vector::registerVectorPasses();
+ arm_sme::registerArmSMEPasses();
+ arm_sve::registerArmSVEPasses();
+ emitc::registerEmitCPasses();
+ xegpu::registerXeGPUPasses();
+
+ // Dialect pipelines
+ bufferization::registerBufferizationPipelines();
+ sparse_tensor::registerSparseTensorPipelines();
+ tosa::registerTosaToLinalgPipelines();
+ gpu::registerGPUToNVVMPipeline();
+}
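The programmatic alternative mentioned in the comment, building a pipeline directly with no global registry, looks roughly like this (a sketch; the pass selection is illustrative):

  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  static mlir::LogicalResult runMyPipeline(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    // Call the pass creation routines directly; no registration is involved.
    pm.addPass(mlir::createCanonicalizerPass());
    pm.addPass(mlir::createCSEPass());
    return pm.run(module);
  }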
diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp
index 748f928..2cf33eb 100644
--- a/mlir/lib/Support/ToolUtilities.cpp
+++ b/mlir/lib/Support/ToolUtilities.cpp
@@ -14,6 +14,8 @@
#include "mlir/Support/LLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <utility>
using namespace mlir;
@@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
ChunkBufferHandler processChunkBuffer,
raw_ostream &os, llvm::StringRef inputSplitMarker,
llvm::StringRef outputSplitMarker) {
+ llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef();
// If splitting is disabled, we process the full input buffer.
if (inputSplitMarker.empty())
- return processChunkBuffer(std::move(originalBuffer), os);
+ return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os);
const int inputSplitMarkerLen = inputSplitMarker.size();
- auto *origMemBuffer = originalBuffer.get();
SmallVector<StringRef, 8> rawSourceBuffers;
const int checkLen = 2;
// Split dropping the last checkLen chars to enable flagging near misses.
- origMemBuffer->getBuffer().split(rawSourceBuffers,
- inputSplitMarker.drop_back(checkLen));
+ originalBufferRef.getBuffer().split(rawSourceBuffers,
+ inputSplitMarker.drop_back(checkLen));
if (rawSourceBuffers.empty())
return success();
@@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
auto interleaveFn = [&](StringRef subBuffer) {
auto splitLoc = SMLoc::getFromPointer(subBuffer.data());
unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first;
- auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy(
- subBuffer, Twine("within split at ") +
- origMemBuffer->getBufferIdentifier() + ":" +
- Twine(splitLine) + " offset ");
- if (failed(processChunkBuffer(std::move(subMemBuffer), os)))
+ std::string name((Twine("within split at ") +
+ originalBufferRef.getBufferIdentifier() + ":" +
+ Twine(splitLine) + " offset ")
+ .str());
+    // Use a MemoryBufferRef to avoid copying the buffer and to keep the chunk
+    // at the same location relative to the original buffer.
+ auto subMemBuffer =
+ llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name),
+ /*RequiresNullTerminator=*/false);
+ if (failed(
+ processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os)))
hadFailure = true;
};
llvm::interleave(sourceBuffers, os, interleaveFn,
@@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
// If any fails, then return a failure of the tool.
return failure(hadFailure);
}
+
+LogicalResult
+mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
+ NoSourceChunkBufferHandler processChunkBuffer,
+ raw_ostream &os, llvm::StringRef inputSplitMarker,
+ llvm::StringRef outputSplitMarker) {
+ auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
+ const llvm::MemoryBufferRef &, raw_ostream &os) {
+ return processChunkBuffer(std::move(chunkBuffer), os);
+ };
+ return splitAndProcessBuffer(std::move(originalBuffer), process, os,
+ inputSplitMarker, outputSplitMarker);
+}
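A hedged sketch of a caller for this convenience overload; `inputBuffer` and the marker strings are assumptions, though "// -----" is the conventional MLIR split marker:

  // Count chunk sizes; the handler never sees the original-buffer bookkeeping.
  auto handler = [](std::unique_ptr<llvm::MemoryBuffer> chunk,
                    llvm::raw_ostream &os) {
    os << "chunk: " << chunk->getBufferSize() << " bytes\n";
    return mlir::success();
  };
  mlir::LogicalResult result = mlir::splitAndProcessBuffer(
      std::move(inputBuffer), handler, llvm::outs(),
      /*inputSplitMarker=*/"// -----", /*outputSplitMarker=*/"// -----");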
diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp
index 01ad910..304253c 100644
--- a/mlir/lib/Support/TypeID.cpp
+++ b/mlir/lib/Support/TypeID.cpp
@@ -27,9 +27,6 @@ namespace {
struct ImplicitTypeIDRegistry {
/// Lookup or insert a TypeID for the given type name.
TypeID lookupOrInsert(StringRef typeName) {
- LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert("
- << typeName << ")\n");
-
// Perform a heuristic check to see if this type is in an anonymous
// namespace. String equality is not valid for anonymous types, so we try to
// abort whenever we see them.
diff --git a/mlir/lib/TableGen/Successor.cpp b/mlir/lib/TableGen/Successor.cpp
index ce0aafb..cd0677d 100644
--- a/mlir/lib/TableGen/Successor.cpp
+++ b/mlir/lib/TableGen/Successor.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Successor.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index 4f74056..b31377e 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -12,7 +12,6 @@
#include "mlir/TableGen/Type.h"
#include "mlir/TableGen/Dialect.h"
-#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Record.h"
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd..7dae03e 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -9,8 +9,6 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/Support/CommandLine.h"
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a393d88..dcd2e11 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -17,15 +17,12 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
-#include <utility>
#define DEBUG_TYPE "translate-to-cpp"
@@ -903,8 +900,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
// inlined, and as such should be wrapped in parentheses in order to guarantee
// its precedence and associativity.
auto requiresParentheses = [&](Value value) {
- auto expressionOp =
- dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+ auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
@@ -1545,7 +1541,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
return success();
}
- auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+ auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);
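The rewrite above (and several hunks below) leans on one idiom: Value::getDefiningOp<OpTy>() folds the null check and the dyn_cast into a single call. Schematically (variable names are illustrative):

  // Before: tolerate a missing defining op, then attempt the cast.
  auto oldStyle = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
  // After: one call; yields null for block arguments or other defining ops.
  auto newStyle = value.getDefiningOp<ExpressionOp>();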
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index 83fbf7a..f6e44c6 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRTargetLLVM
intrinsics_gen
LINK_COMPONENTS
+ BitWriter
Core
IPO
IRReader
@@ -59,7 +60,7 @@ if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
# See: https://gitlab.kitware.com/cmake/cmake/-/issues/24858
# TODO: Bump the MLIR CMake version to 3.26.4 and switch to
# ${CUDAToolkit_LIBRARY_ROOT}
- if(NOT DEFINED ${CUDAToolkit_LIBRARY_ROOT})
+ if(NOT DEFINED CUDAToolkit_LIBRARY_ROOT)
get_filename_component(MLIR_CUDAToolkit_ROOT ${CUDAToolkit_BIN_DIR}
DIRECTORY ABSOLUTE)
else()
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index af22a7f..9ea5c683 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRROCDLToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index 75170bf..8e6f5c7 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Tools/mlir-translate/Translation.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f030fa7..86c731a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -10,3 +10,4 @@ add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
+add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index ff34a08..0f675a0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
@@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}
-static LogicalResult
-convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
- ArrayAttr resAttrsArray, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- if (argAttrsArray) {
- for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
- if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
- !argAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, argAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addParamAttrs(argIdx, *attrBuilder);
- }
- }
- }
-
- if (resAttrsArray && resAttrsArray.size() > 0) {
- if (resAttrsArray.size() != 1)
- return mlir::emitError(loc, "llvm.func cannot have multiple results");
- if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
- !resAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, resAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addRetAttrs(*attrBuilder);
- }
- }
- return success();
-}
-
-static LogicalResult
-convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- return convertParameterAndResultAttrs(
- callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
- moduleTranslation);
-}
-
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));
- if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
- op.getResAttrsAttr(), inst,
- moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst)))
return failure();
if (op.getNumResults() == 1)
@@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getInlineHintAttr())
call->addFnAttr(llvm::Attribute::InlineHint);
- if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call)))
return failure();
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
@@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
- if (failed(
- convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index ad01a64..55e73e8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -13,7 +13,6 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
-
#include "llvm/IR/ConstantRange.h"
using namespace mlir;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index d162afd..97c6b4e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -151,8 +151,7 @@ processDataOperands(llvm::IRBuilderBase &builder,
// Copyin operands are handled as `to` call.
llvm::SmallVector<mlir::Value> create, copyin;
for (mlir::Value dataOp : op.getDataClauseOperands()) {
- if (auto createOp =
- mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
+ if (auto createOp = dataOp.getDefiningOp<acc::CreateOp>()) {
create.push_back(createOp.getVarPtr());
} else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
dataOp.getDefiningOp())) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da39b19..49e1e55 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -16,15 +16,12 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -39,7 +36,6 @@
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
-#include <any>
#include <cstdint>
#include <iterator>
#include <numeric>
@@ -3541,8 +3537,7 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
}
static bool isDeclareTargetLink(mlir::Value value) {
- if (auto addressOfOp =
- llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
+ if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
if (auto declareTargetGlobal =
@@ -3882,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
- llvm::sort(indices.begin(), indices.end(),
- [&](const size_t a, const size_t b) {
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
- for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
- int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
- int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
+ llvm::sort(indices, [&](const size_t a, const size_t b) {
+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
+ for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
+ int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
+ int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
- if (aIndex == bIndex)
- continue;
+ if (aIndex == bIndex)
+ continue;
- if (aIndex < bIndex)
- return first;
+ if (aIndex < bIndex)
+ return first;
- if (aIndex > bIndex)
- return !first;
- }
+ if (aIndex > bIndex)
+ return !first;
+ }
- // Iterated the up until the end of the smallest member and
- // they were found to be equal up to that point, so select
- // the member with the lowest index count, so the "parent"
- return memberIndicesA.size() < memberIndicesB.size();
- });
+      // We iterated up to the end of the smallest member list and found the
+      // entries equal so far, so select the member with the lowest index
+      // count, i.e. the "parent".
+ return memberIndicesA.size() < memberIndicesB.size();
+ });
return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
@@ -4502,8 +4496,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = dataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -4520,8 +4513,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = enterDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
RTLFn =
@@ -4540,8 +4532,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = exitDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -4560,8 +4551,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = updateDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -5202,8 +5192,7 @@ static std::optional<int64_t> extractConstInteger(Value value) {
if (!value)
return std::nullopt;
- if (auto constOp =
- dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
+ if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
return constAttr.getInt();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
new file mode 100644
index 0000000..6308d7e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+ XeVMToLLVMIRTranslation.cpp
+)
+
+add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation
+ XeVMToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRXeVMConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..73b166d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -0,0 +1,103 @@
+//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR XeVM dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the XeVM dialect to LLVM IR.
+class XeVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Attaches module-level metadata for functions marked as kernels.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ StringRef attrName = attribute.getName().getValue();
+ if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) {
+      auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue());
+      if (!cacheControlsArray || cacheControlsArray.size() != 2) {
+ return op->emitOpError(
+ "Expected both L1 and L3 cache control attributes!");
+ }
+ if (instructions.size() != 1) {
+ return op->emitOpError("Expecting a single instruction");
+ }
+ return handleDecorationCacheControl(instructions.front(),
+ cacheControlsArray.getValue());
+ }
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+
+ return success();
+ }
+
+private:
+ static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst,
+ ArrayRef<Attribute> attrs) {
+ SmallVector<llvm::Metadata *> decorations;
+ llvm::LLVMContext &ctx = inst->getContext();
+ llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx);
+ llvm::transform(
+ attrs, std::back_inserter(decorations),
+ [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
+          auto valuesArray = cast<ArrayAttr>(attr).getValue();
+ std::array<llvm::Metadata *, 4> metadata;
+ llvm::transform(
+ valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
+ return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ i32Ty, cast<IntegerAttr>(valueAttr).getValue()));
+ });
+ return llvm::MDNode::get(ctx, metadata);
+ });
+ constexpr llvm::StringLiteral decorationCacheControlMDName =
+ "spirv.DecorationCacheControlINTEL";
+ inst->setMetadata(decorationCacheControlMDName,
+ llvm::MDNode::get(ctx, decorations));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry &registry) {
+ registry.insert<xevm::XeVMDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
+ dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
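Hooking the new translation into a custom tool is then a short sketch that mirrors the registration helpers just defined:

  mlir::DialectRegistry registry;
  mlir::registerXeVMDialectTranslation(registry);
  // ... add further translations, then hand the registry to the context.
  mlir::MLIRContext context(registry);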
diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
index 580afdd..cb1f234 100644
--- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
+++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
@@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs)))
+ llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false,
+ /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands,
+ mlirAttrs)))
return failure();
Type resultType = moduleImport.convertType(inst->getType());
@@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
ValueRange{mlirOperands}, FastmathFlagsAttr{});
moduleImport.setFastmathFlagsAttr(inst, op);
-
- ArrayAttr argsAttr, resAttr;
- moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
- op.setArgAttrsAttr(argsAttr);
- op.setResAttrsAttr(resAttr);
+ moduleImport.convertArgAndResultAttrs(inst, op);
// Update importer tracking of results.
unsigned numRes = op.getNumResults();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 94db7f8..6325480 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
#include "llvm/IR/Constants.h"
@@ -142,6 +143,7 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
// TODO: Implement the `convertInstruction` hooks in the
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
+
return failure();
}
@@ -1062,6 +1064,18 @@ void ModuleImport::convertTargetTriple() {
builder.getStringAttr(llvmModule->getTargetTriple().str()));
}
+void ModuleImport::convertModuleLevelAsm() {
+ llvm::StringRef asmStr = llvmModule->getModuleInlineAsm();
+ llvm::SmallVector<mlir::Attribute> asmArrayAttr;
+
+ for (llvm::StringRef line : llvm::split(asmStr, '\n'))
+ if (!line.empty())
+ asmArrayAttr.push_back(builder.getStringAttr(line));
+
+ mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(),
+ builder.getArrayAttr(asmArrayAttr));
+}
+
LogicalResult ModuleImport::convertFunctions() {
for (llvm::Function &func : llvmModule->functions())
if (failed(processFunction(&func)))
@@ -1626,12 +1640,11 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
// Convert dso_local_equivalent.
if (auto *dsoLocalEquivalent = dyn_cast<llvm::DSOLocalEquivalent>(constant)) {
Type type = convertType(dsoLocalEquivalent->getType());
- return builder
- .create<DSOLocalEquivalentOp>(
- loc, type,
- FlatSymbolRefAttr::get(
- builder.getContext(),
- dsoLocalEquivalent->getGlobalValue()->getName()))
+ return DSOLocalEquivalentOp::create(
+ builder, loc, type,
+ FlatSymbolRefAttr::get(
+ builder.getContext(),
+ dsoLocalEquivalent->getGlobalValue()->getName()))
.getResult();
}
@@ -1736,9 +1749,9 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName());
auto blockTag =
BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber());
- return builder
- .create<BlockAddressOp>(loc, convertType(blockAddr->getType()),
- BlockAddressAttr::get(context, fnSym, blockTag))
+ return BlockAddressOp::create(
+ builder, loc, convertType(blockAddr->getType()),
+ BlockAddressAttr::get(context, fnSym, blockTag))
.getRes();
}
@@ -2228,17 +2241,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (!resultTy)
return failure();
ArrayAttr operandAttrs = convertAsmInlineOperandAttrs(*callInst);
- return builder
- .create<InlineAsmOp>(
- loc, resultTy, *operands,
- builder.getStringAttr(asmI->getAsmString()),
- builder.getStringAttr(asmI->getConstraintString()),
- asmI->hasSideEffects(), asmI->isAlignStack(),
- convertTailCallKindFromLLVM(callInst->getTailCallKind()),
- AsmDialectAttr::get(
- mlirModule.getContext(),
- convertAsmDialectFromLLVM(asmI->getDialect())),
- operandAttrs)
+ return InlineAsmOp::create(
+ builder, loc, resultTy, *operands,
+ builder.getStringAttr(asmI->getAsmString()),
+ builder.getStringAttr(asmI->getConstraintString()),
+ asmI->hasSideEffects(), asmI->isAlignStack(),
+ convertTailCallKindFromLLVM(callInst->getTailCallKind()),
+ AsmDialectAttr::get(
+ mlirModule.getContext(),
+ convertAsmDialectFromLLVM(asmI->getDialect())),
+ operandAttrs)
.getOperation();
}
bool isIncompatibleCall;
@@ -2268,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
- convertParameterAttributes(callInst, callOp, builder);
+ convertArgAndResultAttrs(callInst, callOp);
return callOp.getOperation();
}();
@@ -2365,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
- convertParameterAttributes(invokeInst, invokeOp, builder);
+ convertArgAndResultAttrs(invokeInst, invokeOp);
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
@@ -2731,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
}
DictionaryAttr
-ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder) {
+ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) {
SmallVector<NamedAttribute> paramAttrs;
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
+ auto llvmAttr = llvmAttrSet.getAttribute(llvmKind);
// Skip attributes that are not attached.
if (!llvmAttr.isValid())
continue;
@@ -2770,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
return builder.getDictionaryAttr(paramAttrs);
}
-void ModuleImport::convertParameterAttributes(llvm::Function *func,
- LLVMFuncOp funcOp,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(llvm::Function *func,
+ LLVMFuncOp funcOp) {
auto llvmAttrs = func->getAttributes();
for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
- funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
+ funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs));
}
// Convert the result attributes and attach them wrapped in an ArrayAttribute
// to the funcOp.
@@ -2784,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
if (!llvmResAttr.hasAttributes())
return;
funcOp.setResAttrsAttr(
- builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
+ builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)}));
}
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- ArrayAttr &argsAttr,
- ArrayAttr &resAttr,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(
+ llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions) {
+ // Compute the set of immediate argument positions.
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ // Convert the argument attributes and filter out immediate arguments.
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+ // Skip immediate arguments.
+ if (immArgPositionsSet.contains(i))
+ continue;
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
@@ -2808,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
- argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- argsAttr = getArrayAttr(argAttrs);
+ argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs));
+ attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}
+ // Convert the result attributes.
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
- DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
- resAttr = getArrayAttr({resAttrs});
-}
-
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- CallOpInterface callOp,
- OpBuilder &builder) {
- ArrayAttr argsAttr, resAttr;
- convertParameterAttributes(call, argsAttr, resAttr, builder);
- callOp.setArgAttrsAttr(argsAttr);
- callOp.setResAttrsAttr(resAttr);
+ DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr);
+ attrsOp.setResAttrsAttr(getArrayAttr({resAttrs}));
}
template <typename Op>
@@ -2893,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
builder, loc, func->getName(), functionType,
convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
- convertParameterAttributes(func, funcOp, builder);
+ convertArgAndResultAttrs(func, funcOp);
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
funcOp.setPersonalityAttr(personality);
@@ -3200,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule(
if (failed(moduleImport.convertIFuncs()))
return {};
moduleImport.convertTargetTriple();
+ moduleImport.convertModuleLevelAsm();
return module;
}
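To make the new import hook concrete: it produces one StringAttr per non-empty line of the module's inline asm, wrapped in an ArrayAttr. Built by hand (a sketch; `module` is assumed to be a ModuleOp in scope), the equivalent is:

  mlir::OpBuilder b(module.getContext());
  llvm::SmallVector<mlir::Attribute> lines = {b.getStringAttr(".globl bar"),
                                              b.getStringAttr("bar:")};
  module->setAttr(mlir::LLVM::LLVMDialect::getModuleLevelAsmAttrName(),
                  b.getArrayAttr(lines));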
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b997e55..b3a06e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
return attrBuilder;
}
+LogicalResult ModuleTranslation::convertArgAndResultAttrs(
+ ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions) {
+ // Convert the argument attributes.
+ if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) {
+ unsigned argAttrIdx = 0;
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) {
+ if (argAttrIdx >= argAttrsArray.size())
+ break;
+ // Skip immediate arguments (they have no entries in argAttrsArray).
+ if (immArgPositionsSet.contains(argIdx))
+ continue;
+ // Skip empty argument attributes.
+ auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]);
+ if (argAttrs.empty())
+ continue;
+ // Convert and add attributes to the call instruction.
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addParamAttrs(argIdx, *attrBuilder);
+ }
+ }
+
+ // Convert the result attributes.
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) {
+ if (!resAttrsArray.empty()) {
+ auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), resAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addRetAttrs(*attrBuilder);
+ }
+ }
+
+ return success();
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
@@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
llvmModule->setTargetTriple(
llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue()));
+ if (auto asmAttr = m->getDiscardableAttr(
+ LLVM::LLVMDialect::getModuleLevelAsmAttrName())) {
+ auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr);
+ if (!asmArrayAttr) {
+ m->emitError("expected an array attribute for a module level asm");
+ return nullptr;
+ }
+
+ for (Attribute elt : asmArrayAttr) {
+ auto asmStrAttr = dyn_cast<StringAttr>(elt);
+ if (!asmStrAttr) {
+ m->emitError(
+ "expected a string attribute for each entry of a module level asm");
+ return nullptr;
+ }
+ llvmModule->appendModuleInlineAsm(asmStrAttr.getValue());
+ }
+ }
+
return llvmModule;
}
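A hedged sketch of a call site for the new helper, where operand 1 of the call is an immediate and therefore has no entry in the arg-attrs array (`attrsOp` and `call` are assumed to be in scope):

  // attrsOp implements ArgAndResultAttrsOpInterface; call is the llvm::CallBase
  // just emitted for it. Position 1 is skipped when mapping attribute entries.
  if (failed(moduleTranslation.convertArgAndResultAttrs(
          attrsOp, call, /*immArgPositions=*/{1})))
    return failure();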
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88799a5..88931b5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1188,13 +1185,14 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
offsetInfo[memberIndex] = memberDecoration.second[0];
} else {
+ auto intType = mlir::IntegerType::get(context, 32);
if (!memberDecoration.second.empty()) {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
- memberDecoration.first,
- memberDecoration.second[0]);
+ memberDecorationsInfo.emplace_back(
+ memberIndex, memberDecoration.first,
+ IntegerAttr::get(intType, memberDecoration.second[0]));
} else {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
- memberDecoration.first, 0);
+ memberDecorationsInfo.emplace_back(
+ memberIndex, memberDecoration.first, UnitAttr::get(context));
}
}
}
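
For orientation, a sketch of the two shapes the attribute-based MemberDecorationInfo now takes, inferred from the emplace_back calls above (constructor order: member index, decoration, attribute; `context` is an MLIRContext *):

auto i32Ty = mlir::IntegerType::get(context, 32);
// A decoration carrying a literal value keeps it as an IntegerAttr.
spirv::StructType::MemberDecorationInfo withValue(
    /*memberIndex=*/0, spirv::Decoration::MatrixStride,
    mlir::IntegerAttr::get(i32Ty, 16));
// A value-less decoration is encoded as a UnitAttr, replacing the old
// hasValue=0 sentinel.
spirv::StructType::MemberDecorationInfo unitOnly(
    /*memberIndex=*/1, spirv::Decoration::NonWritable,
    mlir::UnitAttr::get(context));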
@@ -1202,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ if (decorations.count(operands[0])) {
+ NamedAttrList &allDecorations = decorations[operands[0]];
+ for (NamedAttribute &decorationAttr : allDecorations) {
+ std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+ llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+ assert(decoration.has_value());
+ structDecorationsInfo.emplace_back(decoration.value(),
+ decorationAttr.getValue());
+ }
+ }
+
uint32_t structID = operands[0];
std::string structIdentifier = nameMap.lookup(structID).str();
if (structIdentifier.empty()) {
assert(unresolvedMemberTypes.empty() &&
"didn't expect unresolved member types");
- typeMap[structID] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ typeMap[structID] = spirv::StructType::get(
+ memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
} else {
auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
typeMap[structID] = structTy;
if (!unresolvedMemberTypes.empty())
- deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
- memberTypes, offsetInfo,
- memberDecorationsInfo});
+ deferredStructTypesInfos.push_back(
+ {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+ memberDecorationsInfo, structDecorationsInfo});
else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationsInfo)))
+ memberDecorationsInfo,
+ structDecorationsInfo)))
return failure();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd..db1cc3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
SmallVector<Type, 4> memberTypes;
SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
};
/// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 3400fcf..737f296 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -19,7 +19,6 @@
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
@@ -319,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
+ case spirv::Decoration::Block:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -406,8 +406,9 @@ LogicalResult Serializer::processMemberDecoration(
SmallVector<uint32_t, 4> args(
{structID, memberDecoration.memberIndex,
static_cast<uint32_t>(memberDecoration.decoration)});
- if (memberDecoration.hasValue) {
- args.push_back(memberDecoration.decorationValue);
+ if (memberDecoration.hasValue()) {
+ args.push_back(
+ cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
}
encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
return success();
@@ -446,6 +447,19 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
+
+ // Map unsigned integer types to signless integer types.
+ // This is needed because otherwise the generated SPIR-V assembly would
+ // contain a duplicate type declaration (like OpTypeInt 32 0), which is not
+ // permitted and makes the module fail validation: at the MLIR level the two
+ // types are distinct, so the cache lookup below would miss.
+ // Note: This conversion needs to happen here before the type is looked up
+ // in the cache.
+ if (type.isUnsignedInteger()) {
+ type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+
typeID = getTypeID(type);
if (typeID)
return success();
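
A self-contained sketch of the duplicate-declaration problem described in the comment above (builtin-type APIs only):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

int main() {
  mlir::MLIRContext ctx;
  auto ui32 = mlir::IntegerType::get(
      &ctx, 32, mlir::IntegerType::SignednessSemantics::Unsigned);
  auto i32 = mlir::IntegerType::get(
      &ctx, 32, mlir::IntegerType::SignednessSemantics::Signless);
  // Distinct MLIR types, yet both serialize to OpTypeInt 32 0; a type cache
  // keyed on the MLIR type would emit the declaration twice.
  assert(ui32 != i32);
  // The normalization above rewrites ui32 to its signless equivalent before
  // the cache lookup, so both share one SPIR-V type id.
  auto normalized = mlir::IntegerType::get(
      &ctx, ui32.getWidth(), mlir::IntegerType::SignednessSemantics::Signless);
  assert(normalized == i32);
}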
@@ -617,11 +631,16 @@ LogicalResult Serializer::prepareBasicType(
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+ // TODO: Now that struct decorations are supported, this code may no longer
+ // be necessary. It is kept for backwards compatibility; ideally, Block
+ // decorations should be inserted when converting to SPIR-V.
if (isInterfaceStructPtrType(ptrType)) {
- if (failed(emitDecoration(getTypeID(pointeeStruct),
- spirv::Decoration::Block)))
- return emitError(loc, "cannot decorate ")
- << pointeeStruct << " with Block decoration";
+ auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!structType.hasDecoration(spirv::Decoration::Block))
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
}
return success();
@@ -666,10 +685,12 @@ LogicalResult Serializer::prepareBasicType(
}
operands.push_back(elementTypeID);
if (hasOffset) {
+ auto intType = IntegerType::get(structType.getContext(), 32);
// Decorate each struct member with an offset
spirv::StructType::MemberDecorationInfo offsetDecoration{
- elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
- static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
+ elementIndex, spirv::Decoration::Offset,
+ IntegerAttr::get(intType,
+ structType.getMemberOffset(elementIndex))};
if (failed(processMemberDecoration(resultID, offsetDecoration))) {
return emitError(loc, "cannot decorate ")
<< elementIndex << "-th member of " << structType
@@ -689,6 +710,20 @@ LogicalResult Serializer::prepareBasicType(
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
+ structType.getStructDecorations(structDecorations);
+
+ for (spirv::StructType::StructDecorationInfo &structDecoration :
+ structDecorations) {
+ if (failed(processDecorationAttr(loc, resultID,
+ structDecoration.decoration,
+ structDecoration.decorationValue))) {
+ return emitError(loc, "cannot decorate struct ")
+ << structType << " with "
+ << stringifyDecoration(structDecoration.decoration);
+ }
+ }
+
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
@@ -923,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
+ } else if (isa<spirv::TensorArmType>(constType)) {
+ numberOfConstituents = shapedType.getNumElements();
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ uint32_t elementID = 0;
+ if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
+ elementID =
+ elementType.isInteger(1)
+ ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
+ : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
+ }
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
+ elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
+ }
+ if (!elementID) {
+ return 0;
+ }
+ operands.push_back(elementID);
+ }
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
index 04f02f2..e2c987a 100644
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
index 7e708be..b836ece 100644
--- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
@@ -7,8 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/ODS/Operation.h"
-#include "mlir/Support/IndentedOstream.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::pdll::ods;
diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
index 33cdd28..9828704 100644
--- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp
+++ b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
@@ -284,11 +284,11 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value,
if (codeAction->getObject("codeActionLiteralSupport"))
result.codeActionStructure = true;
}
- if (auto *window = textDocument->getObject("window")) {
- if (std::optional<bool> workDoneProgressSupport =
- window->getBoolean("workDoneProgress"))
- result.workDoneProgress = *workDoneProgressSupport;
- }
+ }
+ if (auto *window = o->getObject("window")) {
+ if (std::optional<bool> workDoneProgressSupport =
+ window->getBoolean("workDoneProgress"))
+ result.workDoneProgress = *workDoneProgressSupport;
}
return true;
}
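
The fix moves the window capability lookup from under textDocument to the top-level client capabilities object, matching the LSP spec. A small standalone check of the expected JSON shape (llvm::json only):

#include "llvm/Support/JSON.h"
#include <cassert>

int main() {
  // Per LSP, `window.workDoneProgress` is a sibling of `textDocument` in
  // ClientCapabilities -- the shape the corrected parser reads above.
  auto value = llvm::json::parse(R"({
    "textDocument": {"codeAction": {}},
    "window": {"workDoneProgress": true}
  })");
  assert(value);
  const llvm::json::Object *o = value->getAsObject();
  const llvm::json::Object *window = o->getObject("window");
  assert(window && window->getBoolean("workDoneProgress").value_or(false));
}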
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
index 2504123..9b937db 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -11,8 +11,6 @@
#include "Protocol.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
-#include "llvm/ADT/FunctionExtras.h"
-#include "llvm/ADT/StringMap.h"
#include <optional>
#define DEBUG_TYPE "mlir-lsp-server"
diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
index b1bbf98..f1dc326 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
@@ -9,7 +9,6 @@
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
#include "LSPServer.h"
#include "MLIRServer.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp
index 4ba76fb..a56e9a1 100644
--- a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp
@@ -11,14 +11,7 @@
//===----------------------------------------------------------------------===//
#include "Protocol.h"
-#include "llvm/ADT/Hashing.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/Format.h"
-#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/JSON.h"
-#include "llvm/Support/Path.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::lsp;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 8f78590..bdcdaa4 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -508,13 +508,20 @@ performActions(raw_ostream &os,
/// Parses the memory buffer. If successful, run a series of passes against
/// it and print the result.
-static LogicalResult processBuffer(raw_ostream &os,
- std::unique_ptr<MemoryBuffer> ownedBuffer,
- const MlirOptMainConfig &config,
- DialectRegistry &registry,
- llvm::ThreadPoolInterface *threadPool) {
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+ llvm::MemoryBufferRef sourceBuffer,
+ const MlirOptMainConfig &config, DialectRegistry &registry,
+ SourceMgrDiagnosticVerifierHandler *verifyHandler,
+ llvm::ThreadPoolInterface *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
auto sourceMgr = std::make_shared<SourceMgr>();
+ // Add the original buffer to the source manager to use for determining
+ // locations.
+ sourceMgr->AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(sourceBuffer,
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Create a context just for the current buffer. Disable threading on creation
@@ -522,6 +529,8 @@ static LogicalResult processBuffer(raw_ostream &os,
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
if (threadPool)
context.setThreadPool(*threadPool);
+ if (verifyHandler)
+ verifyHandler->registerInContext(&context);
StringRef irdlFile = config.getIrdlFile();
if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
@@ -545,17 +554,12 @@ static LogicalResult processBuffer(raw_ostream &os,
return performActions(os, sourceMgr, &context, config);
}
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
- *sourceMgr, &context, config.verifyDiagnosticsLevel());
-
// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, sourceMgr, &context, config);
- // Verify the diagnostic handler to make sure that each of the diagnostics
- // matched.
- return sourceMgrHandler.verify();
+ return success();
}
std::pair<std::string, std::string>
@@ -624,14 +628,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
if (threadPoolCtx.isMultithreadingEnabled())
threadPool = &threadPoolCtx.getThreadPool();
+ SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
+ // Note: this creates a verifier handler independent of the flag set, as
+ // internally, if the flag is not set, a new scoped diagnostic handler would
+ // be created that intercepts and verifies the diagnostics.
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
+ sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel());
auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
- raw_ostream &os) {
- return processBuffer(os, std::move(chunkBuffer), config, registry,
- threadPool);
+ llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) {
+ return processBuffer(
+ os, std::move(chunkBuffer), sourceBuffer, config, registry,
+ config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr,
+ threadPool);
};
- return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
- config.inputSplitMarker(),
- config.outputSplitMarker());
+ LogicalResult status = splitAndProcessBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ chunkFn, outputStream, config.inputSplitMarker(),
+ config.outputSplitMarker());
+ if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify()))
+ status = failure();
+ return status;
}
LogicalResult mlir::MlirOptMain(int argc, char **argv,
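
Condensing the mlir-opt changes above: the diagnostic verifier now lives at the top level, spans every split chunk, and is verified once after all chunks ran. A hedged sketch of the resulting flow (names as in the diff, glue elided):

llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(
    llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
                                     /*RequiresNullTerminator=*/false),
    llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
    sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel());
// Each chunk registers the shared handler in its own MLIRContext:
//   sourceMgrHandler.registerInContext(&context);
// ...and verification happens exactly once, after every chunk has run:
if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify()))
  return mlir::failure();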
diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
index 97b8288..685e794 100644
--- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
+++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
@@ -15,7 +15,6 @@
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/GenNameParser.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index f2a81cc..e1c8afb 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -8,9 +8,6 @@
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
#include "mlir/IR/AsmState.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
@@ -138,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
+ // Many of the translations expect a null-terminated buffer, but splitting
+ // the buffer does not guarantee null-termination. Make a copy of the buffer
+ // to ensure null-termination.
+ if (!ownedBuffer->getBuffer().ends_with('\0')) {
+ ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy(
+ ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier());
+ }
// Temporary buffers for chained translation processing.
std::string dataIn;
std::string dataOut;
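
A standalone illustration of the guarantee relied on here: a buffer view need not be null-terminated, while getMemBufferCopy() always null-terminates its copy (llvm::MemoryBuffer only):

#include "llvm/Support/MemoryBuffer.h"
#include <cassert>

int main() {
  // A split chunk is a view into a larger buffer, so the character after it
  // is the next chunk, not necessarily '\0'.
  llvm::StringRef whole = "func.func @f() { return }\n// -----\n...";
  llvm::StringRef chunk = whole.take_front(26);
  auto view = llvm::MemoryBuffer::getMemBuffer(
      chunk, "chunk", /*RequiresNullTerminator=*/false);
  auto copy = llvm::MemoryBuffer::getMemBufferCopy(
      view->getBuffer(), view->getBufferIdentifier());
  assert(copy->getBuffer() == view->getBuffer());
  assert(*copy->getBufferEnd() == '\0'); // the copy is null-terminated
}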
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3a8088b..058039e 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -37,5 +37,4 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
- MLIRUBDialect
)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 4d09c5f..09e5a02 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -19,7 +19,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 4b0ac28..7a99fe8 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,7 +13,6 @@
#include "mlir/Transforms/Passes.h"
-#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index 0dc3fe9..9ebf310 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -8,10 +8,8 @@
#include "mlir/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 608bdcb..4ccb83f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -36,6 +36,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
@@ -51,6 +52,7 @@
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <cstddef>
#include <memory>
@@ -58,8 +60,6 @@
#include <vector>
#define DEBUG_TYPE "remove-dead-values"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir {
#define GEN_PASS_DEF_REMOVEDEADVALUES
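
Background for the mechanical rewrite in the hunks below: the file-local DBGS()/LDBG(X) macros are replaced with the stream-style LDBG() from llvm/Support/DebugLog.h, which keys off DEBUG_TYPE. A minimal usage sketch (assuming the header's per-type filtering and automatic trailing newline):

#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "remove-dead-values"

static void demo(int index) {
  // Emitted only under -debug-only=remove-dead-values (asserts builds);
  // no manual "\n" needed, unlike the old LDBG(X) macro.
  LDBG() << "value at index " << index << " conservatively considered live";
}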
@@ -119,21 +119,21 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
RunLivenessAnalysis &la) {
for (Value value : values) {
if (nonLiveSet.contains(value)) {
- LDBG("Value " << value << " is already marked non-live (dead)");
+ LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
}
const Liveness *liveness = la.getLiveness(value);
if (!liveness) {
- LDBG("Value " << value
- << " has no liveness info, conservatively considered live");
+ LDBG() << "Value " << value
+ << " has no liveness info, conservatively considered live";
return true;
}
if (liveness->isLive) {
- LDBG("Value " << value << " is live according to liveness analysis");
+ LDBG() << "Value " << value << " is live according to liveness analysis";
return true;
} else {
- LDBG("Value " << value << " is dead according to liveness analysis");
+ LDBG() << "Value " << value << " is dead according to liveness analysis";
}
}
return false;
@@ -148,8 +148,8 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
for (auto [index, value] : llvm::enumerate(values)) {
if (nonLiveSet.contains(value)) {
lives.reset(index);
- LDBG("Value " << value << " is already marked non-live (dead) at index "
- << index);
+ LDBG() << "Value " << value
+ << " is already marked non-live (dead) at index " << index;
continue;
}
@@ -161,17 +161,17 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
// (because they weren't erased) and also their liveness is null because
// liveness analysis ran before their creation.
if (!liveness) {
- LDBG("Value " << value << " at index " << index
- << " has no liveness info, conservatively considered live");
+ LDBG() << "Value " << value << " at index " << index
+ << " has no liveness info, conservatively considered live";
continue;
}
if (!liveness->isLive) {
lives.reset(index);
- LDBG("Value " << value << " at index " << index
- << " is dead according to liveness analysis");
+ LDBG() << "Value " << value << " at index " << index
+ << " is dead according to liveness analysis";
} else {
- LDBG("Value " << value << " at index " << index
- << " is live according to liveness analysis");
+ LDBG() << "Value " << value << " at index " << index
+ << " is live according to liveness analysis";
}
}
@@ -187,8 +187,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
if (!nonLive[index])
continue;
nonLiveSet.insert(result);
- LDBG("Marking value " << result << " as non-live (dead) at index "
- << index);
+ LDBG() << "Marking value " << result << " as non-live (dead) at index "
+ << index;
}
}
@@ -258,16 +258,18 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG("Processing simple op: " << *op);
+ LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG("Simple op is not memory effect free or has live results, skipping: "
- << *op);
+ LDBG()
+ << "Simple op is not memory effect free or has live results, skipping: "
+ << *op;
return;
}
- LDBG("Simple op has all dead results and is memory effect free, scheduling "
- "for removal: "
- << *op);
+ LDBG()
+ << "Simple op has all dead results and is memory effect free, scheduling "
+ "for removal: "
+ << *op;
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -286,10 +288,10 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG("Processing function op: " << funcOp.getOperation()->getName());
+ LDBG() << "Processing function op: " << funcOp.getOperation()->getName();
if (funcOp.isPublic() || funcOp.isExternal()) {
- LDBG("Function is public or external, skipping: "
- << funcOp.getOperation()->getName());
+ LDBG() << "Function is public or external, skipping: "
+ << funcOp.getOperation()->getName();
return;
}
@@ -345,8 +347,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// since it forwards only to non-live value(s) (%1#1).
Operation *lastReturnOp = funcOp.back().getTerminator();
size_t numReturns = lastReturnOp->getNumOperands();
- if (numReturns == 0)
- return;
BitVector nonLiveRets(numReturns, true);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -368,6 +368,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
// Do (5) and (6).
+ if (numReturns == 0)
+ return;
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
@@ -409,9 +411,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n");
+ LDBG() << "Processing region branch op: "
+ << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
@@ -697,7 +698,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG("Processing branch op: " << *branchOp);
+ LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d224f73..08803e0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -14,8 +14,10 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -130,11 +132,6 @@ struct ConversionValueMapping {
/// value.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
- /// Lookup the given value within the map, or return an empty vector if the
- /// value is not mapped. If it is mapped, this follows the same behavior
- /// as `lookupOrDefault`.
- ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
-
template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -237,15 +234,6 @@ ConversionValueMapping::lookupOrDefault(Value from,
return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
}
-ValueVector ConversionValueMapping::lookupOrNull(Value from,
- TypeRange desiredTypes) const {
- ValueVector result = lookupOrDefault(from, desiredTypes);
- if (result == ValueVector{from} ||
- (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
- return {};
- return result;
-}
-
//===----------------------------------------------------------------------===//
// Rewriter and Translation State
//===----------------------------------------------------------------------===//
@@ -521,9 +509,11 @@ private:
class MoveBlockRewrite : public BlockRewrite {
public:
MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
- Region *region, Block *insertBeforeBlock)
- : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
- insertBeforeBlock(insertBeforeBlock) {}
+ Region *previousRegion, Region::iterator previousIt)
+ : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
+ region(previousRegion),
+ insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
+ : &*previousIt) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveBlock;
@@ -630,9 +620,12 @@ protected:
class MoveOperationRewrite : public OperationRewrite {
public:
MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Operation *op, Block *block, Operation *insertBeforeOp)
- : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
- insertBeforeOp(insertBeforeOp) {}
+ Operation *op, OpBuilder::InsertPoint previous)
+ : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
+ block(previous.getBlock()),
+ insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
+ ? nullptr
+ : &*previous.getPoint()) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveOperation;
@@ -926,6 +919,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return "true" if the given operation was replaced or erased.
bool wasOpReplaced(Operation *op) const;
+ /// Lookup the most recently mapped values with the desired types in the
+ /// mapping.
+ ///
+ /// Special cases:
+ /// - If the desired type range is empty, simply return the most recently
+ /// mapped values.
+ /// - If there is no mapping to the desired types, also return the most
+ /// recently mapped values.
+ /// - If there is no mapping for the given value at all, return the given
+ /// value.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+
+ /// Lookup the given value within the map, or return an empty vector if the
+ /// value is not mapped. If it is mapped, this follows the same behavior
+ /// as `lookupOrDefault`.
+ ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
+
//===--------------------------------------------------------------------===//
// IR Rewrites / Type Conversion
//===--------------------------------------------------------------------===//
@@ -1248,6 +1258,22 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
//===----------------------------------------------------------------------===//
+ValueVector
+ConversionPatternRewriterImpl::lookupOrDefault(Value from,
+ TypeRange desiredTypes) const {
+ return mapping.lookupOrDefault(from, desiredTypes);
+}
+
+ValueVector
+ConversionPatternRewriterImpl::lookupOrNull(Value from,
+ TypeRange desiredTypes) const {
+ ValueVector result = lookupOrDefault(from, desiredTypes);
+ if (result == ValueVector{from} ||
+ (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
+ return {};
+ return result;
+}
+
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}
@@ -1295,7 +1321,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped values.
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back(lookupOrDefault(operand));
continue;
}
@@ -1314,7 +1340,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
continue;
}
- ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
+ ValueVector repl = lookupOrDefault(operand, legalTypes);
if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
// Mapped values have the correct type or there is an existing
// materialization. Or the operand is not mapped at all and has the
@@ -1324,7 +1350,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
// Create a materialization for the most recently mapped values.
- repl = mapping.lookupOrDefault(operand);
+ repl = lookupOrDefault(operand);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1519,7 +1545,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
- ValueVector repl = mapping.lookupOrNull(value, value.getType());
+ ValueVector repl = lookupOrNull(value, value.getType());
if (!repl.empty())
return repl.front();
@@ -1535,7 +1561,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// No replacement value was found. Get the latest replacement value
// (regardless of the type) and build a source materialization to the
// original type.
- repl = mapping.lookupOrNull(value);
+ repl = lookupOrNull(value);
if (repl.empty()) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
@@ -1568,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
+ // If no previous insertion point is provided, the op used to be detached.
+ bool wasDetached = !previous.isSet();
LLVM_DEBUG({
- logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
- << ")\n";
+ logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
+ << ")";
+ if (wasDetached)
+ logger.getOStream() << " (was detached)";
+ logger.getOStream() << "\n";
});
assert(!wasOpReplaced(op->getParentOp()) &&
"attempting to insert into a block within a replaced/erased op");
- if (!previous.isSet()) {
- // This is a newly created op.
+ if (wasDetached) {
+ // If the op was detached, it is most likely a newly created op.
+ // TODO: If the same op is inserted multiple times from a detached state,
+ // the rollback mechanism may erase the same op multiple times. This is a
+ // bug in the rollback-based dialect conversion driver.
appendRewrite<CreateOperationRewrite>(op);
patternNewOps.insert(op);
return;
}
- Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
- ? nullptr
- : &*previous.getPoint();
- appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
+
+ // The op was moved from one place to another.
+ appendRewrite<MoveOperationRewrite>(op, previous);
}
void ConversionPatternRewriterImpl::replaceOp(
@@ -1649,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ // If no previous insertion point is provided, the block used to be detached.
+ bool wasDetached = !previous;
+ Operation *newParentOp = block->getParentOp();
LLVM_DEBUG(
{
- Operation *parent = block->getParentOp();
+ Operation *parent = newParentOp;
if (parent) {
logger.startLine() << "** Insert Block into : '" << parent->getName()
- << "'(" << parent << ")\n";
+ << "' (" << parent << ")";
} else {
logger.startLine()
- << "** Insert Block into detached Region (nullptr parent op)'\n";
+ << "** Insert Block into detached Region (nullptr parent op)";
}
+ if (wasDetached)
+ logger.getOStream() << " (was detached)";
+ logger.getOStream() << "\n";
});
+ assert(!wasOpReplaced(newParentOp) &&
+ "attempting to insert into a region within a replaced/erased op");
+ (void)newParentOp;
patternInsertedBlocks.insert(block);
- if (!previous) {
- // This is a newly created block.
+ if (wasDetached) {
+ // If the block was detached, it is most likely a newly created block.
+ // TODO: If the same block is inserted multiple times from a detached state,
+ // the rollback mechanism may erase the same block multiple times. This is a
+ // bug in the rollback-based dialect conversion driver.
appendRewrite<CreateBlockRewrite>(block);
return;
}
- Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
- appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
+
+ // The block was moved from one place to another.
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1716,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> newVals =
llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1731,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
impl->replaceOp(op, std::move(newValues));
}
@@ -1739,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
impl->replaceOp(op, std::move(nullRepls));
}
@@ -1845,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
eraseBlock(source);
}
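
These insertion-point adjustments (here and in replaceOp, replaceOpWithMultiple, and eraseOp above) all guard the same hazard; a hedged fragment from a pattern author's point of view (assumes rewriter and op in scope inside a conversion pattern):

mlir::Location loc = op->getLoc();
rewriter.setInsertionPoint(op);  // insertion point == op->getIterator()
rewriter.eraseOp(op);            // the rewriter now slides the point past op
// Subsequent insertions remain well-defined instead of using a dead iterator:
rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);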
@@ -1976,6 +2043,7 @@ private:
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
+ const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks);
@@ -2092,8 +2160,9 @@ OperationLegalizer::legalize(Operation *op,
// If the operation has no regions, just print it here.
if (!isIgnored && op->getNumRegions() == 0) {
- op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
- logger.getOStream() << "\n\n";
+ logger.startLine() << OpWithFlags(op,
+ OpPrintingFlags().printGenericOpForm())
+ << "\n";
}
});
@@ -2172,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
});
- (void)rewriterImpl;
+
+ // Clear pattern state, so that the next pattern application starts with a
+ // clean slate. (The op/block sets are populated by listener notifications.)
+ auto cleanup = llvm::make_scope_exit([&]() {
+ rewriterImpl.patternNewOps.clear();
+ rewriterImpl.patternModifiedOps.clear();
+ rewriterImpl.patternInsertedBlocks.clear();
+ });
+
+ // Upon failure, undo all changes made by the folder.
+ RewriterState curState = rewriterImpl.getCurrentState();
// Try to fold the operation.
StringRef opName = op->getName().getStringRef();
SmallVector<Value, 2> replacementValues;
SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
+ rewriter.startOpModification(op);
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+ rewriter.cancelOpModification(op);
return failure();
}
+ rewriter.finalizeOpModification(op);
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
if (failed(legalize(newOp, rewriter))) {
@@ -2201,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
"op '" + opName +
"' folder rollback of IR modifications requested");
}
- // Legalization failed: erase all materialized constants.
- for (Operation *op : newOps)
- rewriter.eraseOp(op);
+ rewriterImpl.resetState(
+ curState, std::string(op->getName().getStringRef()) + " folder");
return failure();
}
}
- // Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, replacementValues);
-
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
return success();
}
@@ -2220,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ Operation *checkOp;
+ std::optional<OperationFingerPrint> topLevelFingerPrint;
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // The op may be getting erased, so we have to check the parent op.
+ // (In rare cases, a pattern may even erase the parent op, which will cause
+ // a crash here. Expensive checks are "best effort".) Skip the check if the
+ // op does not have a parent op.
+ if ((checkOp = op->getParentOp())) {
+ if (!op->getContext()->isMultithreadingEnabled()) {
+ topLevelFingerPrint = OperationFingerPrint(checkOp);
+ } else {
+ // Another thread may be modifying a sibling operation. Therefore, the
+ // fingerprinting mechanism of the parent op works only in
+ // single-threaded mode.
+ LLVM_DEBUG({
+ rewriterImpl.logger.startLine()
+ << "WARNING: Multi-threading is enabled. Some expensive dialect "
+ "conversion checks are skipped in multithreading "
+ "mode!\n";
+ });
+ }
+ }
+ }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
bool canApply = canApplyPattern(op, pattern, rewriter);
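
The new expensive check pairs a fingerprint taken before pattern application with one taken on the failure path; a condensed sketch of the mechanism (parentOp assumed in scope, single-threaded mode):

#include "mlir/IR/OperationSupport.h"   // mlir::OperationFingerPrint
#include "llvm/Support/ErrorHandling.h" // llvm::report_fatal_error

// Fingerprint the parent before the pattern runs...
mlir::OperationFingerPrint before(parentOp);
// ...apply the pattern; suppose it returns failure()...
mlir::OperationFingerPrint after(parentOp);
if (after != before)
  llvm::report_fatal_error("pattern returned failure but IR did change");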
@@ -2232,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Returning "failure" after modifying IR is not allowed.
+ if (checkOp) {
+ OperationFingerPrint fingerPrintAfterPattern(checkOp);
+ if (fingerPrintAfterPattern != *topLevelFingerPrint)
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' returned failure but IR did change");
+ }
+ }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2260,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, newOps,
+ auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2303,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const SetVector<Operation *> &newOps,
+ const RewriterState &curState, const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2319,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
if (!replacedRoot() && !updatedRootInPlace())
- llvm::report_fatal_error("expected pattern to replace the root operation");
+ llvm::report_fatal_error(
+ "expected pattern to replace the root operation or modify it in place");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index b82d850..607b86c 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index b639e87f..26c965c 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -21,7 +21,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "inlining"
@@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//
-#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
-#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
@@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
LLVM_DEBUG({
- llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+ LDBG() << "* Inliner: Initial calls in SCC are: {";
for (unsigned i = 0, e = calls.size(); i < e; ++i)
- llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
- llvm::dbgs() << "}\n";
+ LDBG() << " " << i << ". " << calls[i].call << ",";
+ LDBG() << "}";
});
// Try to inline each of the call operations. Don't cache the end iterator
@@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
- llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Inlining call: " << i << ". " << call;
else
- llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Not inlining call: " << i << ". " << call;
});
if (!doInline)
continue;
@@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
- LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
+ LDBG() << "** Failed to inline";
continue;
}
inlinedAnyCalls = true;
@@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
auto historyToString = [](InlineHistoryT h) {
return h.has_value() ? std::to_string(*h) : "root";
};
- (void)historyToString;
- LLVM_DEBUG(llvm::dbgs()
- << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
- << getNodeName(call) << ", " << historyToString(inlineHistoryID)
- << "]\n");
+ LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+ << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+ << "]";
for (unsigned k = prevSize; k != calls.size(); ++k) {
callHistory.push_back(newInlineHistoryID);
- LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
- << "}\n with historyID = " << newInlineHistoryID
- << ", added due to inlining of\n call {" << call
- << "}\n with historyID = "
- << historyToString(inlineHistoryID) << "\n");
+ LDBG() << "* new call " << k << " {" << calls[k].call
+ << "}\n with historyID = " << newInlineHistoryID
+ << ", added due to inlining of\n call {" << call
+ << "}\n with historyID = " << historyToString(inlineHistoryID);
}
// If the inlining was successful, Merge the new uses into the source node.