aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/CMakeLists.txt12
-rw-r--r--mlir/Maintainers.md70
-rw-r--r--mlir/cmake/modules/AddMLIR.cmake64
-rw-r--r--mlir/cmake/modules/AddMLIRPython.cmake7
-rw-r--r--mlir/cmake/modules/FindSyclRuntime.cmake2
-rw-r--r--mlir/docs/Bindings/Python.md36
-rw-r--r--mlir/docs/DialectConversion.md47
-rw-r--r--mlir/docs/Dialects/GPU.md17
-rw-r--r--mlir/docs/Remarks.md259
-rw-r--r--mlir/docs/Tutorials/Toy/Ch-4.md2
-rw-r--r--mlir/examples/standalone/CMakeLists.txt8
-rw-r--r--mlir/examples/standalone/python/CMakeLists.txt15
-rw-r--r--mlir/examples/toy/Ch1/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch1/toyc.cpp3
-rw-r--r--mlir/examples/toy/Ch2/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch2/toyc.cpp7
-rw-r--r--mlir/examples/toy/Ch3/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch3/toyc.cpp11
-rw-r--r--mlir/examples/toy/Ch4/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch4/toyc.cpp11
-rw-r--r--mlir/examples/toy/Ch5/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch5/toyc.cpp11
-rw-r--r--mlir/examples/toy/Ch6/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch6/toyc.cpp17
-rw-r--r--mlir/examples/toy/Ch7/parser/AST.cpp2
-rw-r--r--mlir/examples/toy/Ch7/toyc.cpp17
-rw-r--r--mlir/include/mlir-c/ExecutionEngine.h7
-rw-r--r--mlir/include/mlir-c/IR.h3
-rw-r--r--mlir/include/mlir/Analysis/Presburger/IntegerRelation.h31
-rw-r--r--mlir/include/mlir/Analysis/Presburger/Matrix.h2
-rw-r--r--mlir/include/mlir/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Conversion/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h14
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/Pattern.h20
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h31
-rw-r--r--mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h3
-rw-r--r--mlir/include/mlir/Conversion/Passes.h2
-rw-r--r--mlir/include/mlir/Conversion/Passes.td32
-rw-r--r--mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h27
-rw-r--r--mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h26
-rw-r--r--mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h27
-rw-r--r--mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td45
-rw-r--r--mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt4
-rw-r--r--mlir/include/mlir/Dialect/AMDGPU/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/AMX/AMX.td15
-rw-r--r--mlir/include/mlir/Dialect/Affine/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Affine/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Arith/IR/ArithBase.td44
-rw-r--r--mlir/include/mlir/Dialect/Arith/IR/ArithOps.td140
-rw-r--r--mlir/include/mlir/Dialect/Arith/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt8
-rw-r--r--mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Async/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/CommonFolders.h147
-rw-r--r--mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/DLTI/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt5
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.h3
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.td37
-rw-r--r--mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Func/IR/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Func/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/GPU/IR/CMakeLists.txt12
-rw-r--r--mlir/include/mlir/Dialect/GPU/IR/GPUOps.td73
-rw-r--r--mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td5
-rw-r--r--mlir/include/mlir/Dialect/GPU/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/GPU/Transforms/Passes.h3
-rw-r--r--mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt12
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h47
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td13
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt26
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td77
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h4
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td1
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td38
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td128
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td3
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td19
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h10
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td387
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td87
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Linalg/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt15
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/Linalg.h194
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td2
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml286
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td29
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.td97
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td5
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h18
-rw-r--r--mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Dialect/Math/IR/MathOps.td31
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/Passes.h26
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/Passes.td20
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td4
-rw-r--r--mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt8
-rw-r--r--mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/NVGPU/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt11
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td33
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt7
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td8
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td117
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td19
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt8
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h10
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td12
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h3
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h21
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td54
-rw-r--r--mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td127
-rw-r--r--mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/SCF/IR/SCFOps.td20
-rw-r--r--mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/SCF/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/SCF/Utils/Utils.h8
-rw-r--r--mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt18
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td4
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h2
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt7
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Tensor/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Tensor/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h26
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt8
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td45
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td139
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td23
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td4
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/Transform/DebugExtension/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt12
-rw-r--r--mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Transform/Interfaces/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/UB/IR/CMakeLists.txt4
-rw-r--r--mlir/include/mlir/Dialect/Utils/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.h1
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td221
-rw-r--r--mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td11
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h8
-rw-r--r--mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h17
-rw-r--r--mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td12
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt11
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h11
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td302
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td14
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td621
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td58
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt3
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h74
-rw-r--r--mlir/include/mlir/ExecutionEngine/ExecutionEngine.h9
-rw-r--r--mlir/include/mlir/ExecutionEngine/MemRefUtils.h2
-rw-r--r--mlir/include/mlir/IR/Block.h14
-rw-r--r--mlir/include/mlir/IR/CMakeLists.txt21
-rw-r--r--mlir/include/mlir/IR/CommonAttrConstraints.td4
-rw-r--r--mlir/include/mlir/IR/EnumAttr.td2
-rw-r--r--mlir/include/mlir/IR/MLIRContext.h10
-rw-r--r--mlir/include/mlir/IR/OpBase.td45
-rw-r--r--mlir/include/mlir/IR/Operation.h8
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h8
-rw-r--r--mlir/include/mlir/IR/Properties.td7
-rw-r--r--mlir/include/mlir/IR/Remarks.h520
-rw-r--r--mlir/include/mlir/InitAllTranslations.h2
-rw-r--r--mlir/include/mlir/Interfaces/CMakeLists.txt6
-rw-r--r--mlir/include/mlir/Interfaces/SideEffectInterfaces.h56
-rw-r--r--mlir/include/mlir/Interfaces/ViewLikeInterface.td12
-rw-r--r--mlir/include/mlir/Pass/PassOptions.h2
-rw-r--r--mlir/include/mlir/Query/Matcher/SliceMatchers.h2
-rw-r--r--mlir/include/mlir/Reducer/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Remark/RemarkStreamer.h49
-rw-r--r--mlir/include/mlir/Target/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Target/LLVM/XeVM/Target.h30
-rw-r--r--mlir/include/mlir/Target/LLVM/XeVM/Utils.h63
-rw-r--r--mlir/include/mlir/Target/LLVMIR/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Target/LLVMIR/DataLayoutImporter.h (renamed from mlir/lib/Target/LLVMIR/DataLayoutImporter.h)46
-rw-r--r--mlir/include/mlir/Target/LLVMIR/Dialect/All.h2
-rw-r--r--mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h31
-rw-r--r--mlir/include/mlir/Target/LLVMIR/Transforms/CMakeLists.txt5
-rw-r--r--mlir/include/mlir/Target/LLVMIR/Transforms/Passes.h26
-rw-r--r--mlir/include/mlir/Target/LLVMIR/Transforms/Passes.td46
-rw-r--r--mlir/include/mlir/Target/LLVMIR/Transforms/TargetUtils.h35
-rw-r--r--mlir/include/mlir/Target/SPIRV/Serialization.h28
-rw-r--r--mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h143
-rw-r--r--mlir/include/mlir/Target/Wasm/WasmImporter.h31
-rw-r--r--mlir/include/mlir/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h217
-rw-r--r--mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h2
-rw-r--r--mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp9
-rw-r--r--mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp111
-rw-r--r--mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp17
-rw-r--r--mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp29
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp46
-rw-r--r--mlir/lib/Analysis/DataFlowFramework.cpp21
-rw-r--r--mlir/lib/Analysis/FlatLinearValueConstraints.cpp2
-rw-r--r--mlir/lib/Analysis/Presburger/Barvinok.cpp14
-rw-r--r--mlir/lib/Analysis/Presburger/IntegerRelation.cpp29
-rw-r--r--mlir/lib/Analysis/Presburger/Matrix.cpp6
-rw-r--r--mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp8
-rw-r--r--mlir/lib/Analysis/Presburger/Simplex.cpp8
-rw-r--r--mlir/lib/Analysis/TopologicalSortUtils.cpp7
-rw-r--r--mlir/lib/Bindings/Python/DialectGPU.cpp9
-rw-r--r--mlir/lib/Bindings/Python/DialectLLVM.cpp16
-rw-r--r--mlir/lib/Bindings/Python/DialectNVGPU.cpp6
-rw-r--r--mlir/lib/Bindings/Python/DialectPDL.cpp14
-rw-r--r--mlir/lib/Bindings/Python/DialectQuant.cpp11
-rw-r--r--mlir/lib/Bindings/Python/DialectSMT.cpp2
-rw-r--r--mlir/lib/Bindings/Python/DialectSparseTensor.cpp7
-rw-r--r--mlir/lib/Bindings/Python/DialectTransform.cpp15
-rw-r--r--mlir/lib/Bindings/Python/ExecutionEngineModule.cpp17
-rw-r--r--mlir/lib/Bindings/Python/Globals.h39
-rw-r--r--mlir/lib/Bindings/Python/IRAffine.cpp25
-rw-r--r--mlir/lib/Bindings/Python/IRAttributes.cpp54
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp396
-rw-r--r--mlir/lib/Bindings/Python/IRModule.cpp70
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h85
-rw-r--r--mlir/lib/Bindings/Python/IRTypes.cpp2
-rw-r--r--mlir/lib/Bindings/Python/MainModule.cpp23
-rw-r--r--mlir/lib/Bindings/Python/Pass.cpp10
-rw-r--r--mlir/lib/Bindings/Python/RegisterEverything.cpp2
-rw-r--r--mlir/lib/Bindings/Python/TransformInterpreter.cpp1
-rw-r--r--mlir/lib/Bytecode/Writer/BytecodeWriter.cpp15
-rw-r--r--mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp9
-rw-r--r--mlir/lib/CAPI/IR/BuiltinTypes.cpp4
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp4
-rw-r--r--mlir/lib/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp51
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp17
-rw-r--r--mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp37
-rw-r--r--mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp2
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt3
-rw-r--r--mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp36
-rw-r--r--mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp41
-rw-r--r--mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp33
-rw-r--r--mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp139
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp3
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp6
-rw-r--r--mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp85
-rw-r--r--mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp62
-rw-r--r--mlir/lib/Conversion/LLVMCommon/Pattern.cpp121
-rw-r--r--mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp115
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp116
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp56
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp15
-rw-r--r--mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp28
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp43
-rw-r--r--mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp10
-rw-r--r--mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp54
-rw-r--r--mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt17
-rw-r--r--mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp440
-rw-r--r--mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp7
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp4
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp4
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp17
-rw-r--r--mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp10
-rw-r--r--mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp4
-rw-r--r--mlir/lib/Conversion/TosaToArith/TosaToArith.cpp14
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp90
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp16
-rw-r--r--mlir/lib/Conversion/VectorToAMX/CMakeLists.txt19
-rw-r--r--mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp429
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp27
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp1
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp40
-rw-r--r--mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp313
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt27
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp1026
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp16
-rw-r--r--mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp109
-rw-r--r--mlir/lib/Dialect/AMX/IR/AMXDialect.cpp4
-rw-r--r--mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp4
-rw-r--r--mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp21
-rw-r--r--mlir/lib/Dialect/Affine/Analysis/Utils.cpp78
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp32
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp199
-rw-r--r--mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp21
-rw-r--r--mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp93
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp5
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp12
-rw-r--r--mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp60
-rw-r--r--mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp2
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp6
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp6
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp8
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp15
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp82
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp14
-rw-r--r--mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/DLTI/Traits.cpp2
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp69
-rw-r--r--mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp150
-rw-r--r--mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp4
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp98
-rw-r--r--mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp13
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp57
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp12
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp48
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp1
-rw-r--r--mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp38
-rw-r--r--mlir/lib/Dialect/LLVMIR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp487
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp120
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp71
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp6
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp4
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp31
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp440
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp7
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp508
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp43
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp12
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp256
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp275
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp22
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp65
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp62
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp98
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp28
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp (renamed from mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp)15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp132
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp9
-rw-r--r--mlir/lib/Dialect/Math/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (renamed from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp)139
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp1
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp6
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp70
-rw-r--r--mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp6
-rw-r--r--mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp6
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp78
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp153
-rw-r--r--mlir/lib/Dialect/Ptr/IR/CMakeLists.txt20
-rw-r--r--mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp15
-rw-r--r--mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp12
-rw-r--r--mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp122
-rw-r--r--mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp52
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp3
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp9
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp5
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp15
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp63
-rw-r--r--mlir/lib/Dialect/SCF/Utils/Utils.cpp46
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp2
-rw-r--r--mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp5
-rw-r--r--mlir/lib/Dialect/Shard/Transforms/Partition.cpp10
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp4
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp7
-rw-r--r--mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp11
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp5
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp10
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp16
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp568
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp23
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp238
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp12
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp40
-rw-r--r--mlir/lib/Dialect/Transform/IR/Utils.cpp33
-rw-r--r--mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp4
-rw-r--r--mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp7
-rw-r--r--mlir/lib/Dialect/Utils/StaticValueUtils.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp217
-rw-r--r--mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp5
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp65
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp45
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp63
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp39
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp10
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp26
-rw-r--r--mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp42
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt5
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp355
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp388
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp39
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp4
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp20
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp630
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp98
-rw-r--r--mlir/lib/ExecutionEngine/ExecutionEngine.cpp20
-rw-r--r--mlir/lib/ExecutionEngine/JitRunner.cpp2
-rw-r--r--mlir/lib/ExecutionEngine/VulkanRuntime.cpp2
-rw-r--r--mlir/lib/IR/Block.cpp10
-rw-r--r--mlir/lib/IR/BuiltinAttributes.cpp17
-rw-r--r--mlir/lib/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/IR/Dialect.cpp12
-rw-r--r--mlir/lib/IR/Location.cpp4
-rw-r--r--mlir/lib/IR/MLIRContext.cpp32
-rw-r--r--mlir/lib/IR/ODSSupport.cpp4
-rw-r--r--mlir/lib/IR/Remarks.cpp279
-rw-r--r--mlir/lib/Interfaces/SideEffectInterfaces.cpp5
-rw-r--r--mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp3
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp2
-rw-r--r--mlir/lib/Query/Matcher/VariantValue.cpp5
-rw-r--r--mlir/lib/Query/Query.cpp2
-rw-r--r--mlir/lib/RegisterAllDialects.cpp2
-rw-r--r--mlir/lib/RegisterAllExtensions.cpp3
-rw-r--r--mlir/lib/RegisterAllPasses.cpp2
-rw-r--r--mlir/lib/Remark/CMakeLists.txt14
-rw-r--r--mlir/lib/Remark/RemarkStreamer.cpp69
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp300
-rw-r--r--mlir/lib/Rewrite/PatternApplicator.cpp6
-rw-r--r--mlir/lib/Target/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/Cpp/TranslateToCpp.cpp52
-rw-r--r--mlir/lib/Target/LLVM/CMakeLists.txt24
-rw-r--r--mlir/lib/Target/LLVM/NVVM/Target.cpp48
-rw-r--r--mlir/lib/Target/LLVM/XeVM/Target.cpp418
-rw-r--r--mlir/lib/Target/LLVMIR/CMakeLists.txt2
-rw-r--r--mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp63
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp12
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp139
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp133
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt12
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp66
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp140
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp180
-rw-r--r--mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt23
-rw-r--r--mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp62
-rw-r--r--mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp78
-rw-r--r--mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp71
-rw-r--r--mlir/lib/Target/LLVMIR/TypeToLLVM.cpp11
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp16
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp53
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.h2
-rw-r--r--mlir/lib/Target/SPIRV/TranslateRegistration.cpp59
-rw-r--r--mlir/lib/Target/Wasm/CMakeLists.txt13
-rw-r--r--mlir/lib/Target/Wasm/TranslateFromWasm.cpp1522
-rw-r--r--mlir/lib/Target/Wasm/TranslateRegistration.cpp28
-rw-r--r--mlir/lib/Tools/PDLL/Parser/Parser.cpp3
-rw-r--r--mlir/lib/Tools/mlir-query/MlirQueryMain.cpp11
-rw-r--r--mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp13
-rw-r--r--mlir/lib/Transforms/CSE.cpp12
-rw-r--r--mlir/lib/Transforms/InlinerPass.cpp7
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp4
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp49
-rw-r--r--mlir/lib/Transforms/SROA.cpp4
-rw-r--r--mlir/lib/Transforms/SymbolDCE.cpp30
-rw-r--r--mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp9
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp703
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp42
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp13
-rw-r--r--mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp12
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp33
-rw-r--r--mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp126
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/ir.pyi7
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi1
-rw-r--r--mlir/python/mlir/dialects/_ods_common.py8
-rw-r--r--mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py93
-rw-r--r--mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir20
-rw-r--r--mlir/test/CAPI/CMakeLists.txt7
-rw-r--r--mlir/test/CAPI/global_constructors.c113
-rw-r--r--mlir/test/CMakeLists.txt8
-rw-r--r--mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir163
-rw-r--r--mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir12
-rw-r--r--mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir8
-rw-r--r--mlir/test/Conversion/ArithToLLVM/type-conversion.mlir15
-rw-r--r--mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir7
-rw-r--r--mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir43
-rw-r--r--mlir/test/Conversion/ControlFlowToLLVM/assert.mlir1
-rw-r--r--mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir12
-rw-r--r--mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir1
-rw-r--r--mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir1
-rw-r--r--mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir58
-rw-r--r--mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir1
-rw-r--r--mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir1
-rw-r--r--mlir/test/Conversion/MathToLibm/convert-to-libm.mlir228
-rw-r--r--mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir50
-rw-r--r--mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir29
-rw-r--r--mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir8
-rw-r--r--mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir4
-rw-r--r--mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir114
-rw-r--r--mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir22
-rw-r--r--mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir8
-rw-r--r--mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir145
-rw-r--r--mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir1
-rw-r--r--mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir318
-rw-r--r--mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir18
-rw-r--r--mlir/test/Conversion/SCFToSPIRV/for.mlir14
-rw-r--r--mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir2
-rw-r--r--mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir8
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir10
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir8
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir32
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir95
-rw-r--r--mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir1
-rw-r--r--mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir310
-rw-r--r--mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir355
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir44
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir37
-rw-r--r--mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir17
-rw-r--r--mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir424
-rw-r--r--mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir290
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir80
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/dpas.mlir18
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/fence.mlir15
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir75
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir261
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir75
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir42
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir25
-rw-r--r--mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir91
-rw-r--r--mlir/test/Dialect/AMDGPU/invalid.mlir8
-rw-r--r--mlir/test/Dialect/AMDGPU/ops.mlir18
-rw-r--r--mlir/test/Dialect/AMX/invalid.mlir126
-rw-r--r--mlir/test/Dialect/AMX/side-effects.mlir32
-rw-r--r--mlir/test/Dialect/Affine/loop-permute.mlir32
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir29
-rw-r--r--mlir/test/Dialect/Arith/int-range-interface.mlir9
-rw-r--r--mlir/test/Dialect/EmitC/attrs.mlir4
-rw-r--r--mlir/test/Dialect/EmitC/form-expressions.mlir (renamed from mlir/test/Dialect/EmitC/transforms.mlir)56
-rw-r--r--mlir/test/Dialect/EmitC/invalid_ops.mlir59
-rw-r--r--mlir/test/Dialect/EmitC/ops.mlir23
-rw-r--r--mlir/test/Dialect/EmitC/types.mlir6
-rw-r--r--mlir/test/Dialect/EmitC/wrap-func-in-class.mlir57
-rw-r--r--mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir40
-rw-r--r--mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir17
-rw-r--r--mlir/test/Dialect/GPU/broadcast-speculatability.mlir24
-rw-r--r--mlir/test/Dialect/GPU/int-range-interface.mlir19
-rw-r--r--mlir/test/Dialect/GPU/ops.mlir28
-rw-r--r--mlir/test/Dialect/GPU/outlining.mlir56
-rw-r--r--mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir14
-rw-r--r--mlir/test/Dialect/LLVMIR/call-intrin.mlir9
-rw-r--r--mlir/test/Dialect/LLVMIR/func.mlir6
-rw-r--r--mlir/test/Dialect/LLVMIR/inlining.mlir18
-rw-r--r--mlir/test/Dialect/LLVMIR/invalid.mlir36
-rw-r--r--mlir/test/Dialect/LLVMIR/mem2reg.mlir35
-rw-r--r--mlir/test/Dialect/LLVMIR/nvvm.mlir43
-rw-r--r--mlir/test/Dialect/LLVMIR/ptr.mlir9
-rw-r--r--mlir/test/Dialect/LLVMIR/rocdl.mlir23
-rw-r--r--mlir/test/Dialect/LLVMIR/roundtrip.mlir8
-rw-r--r--mlir/test/Dialect/LLVMIR/sroa.mlir2
-rw-r--r--mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir50
-rw-r--r--mlir/test/Dialect/Linalg/block-pack-matmul.mlir144
-rw-r--r--mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir126
-rw-r--r--mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir66
-rw-r--r--mlir/test/Dialect/Linalg/data-layout-propagation.mlir113
-rw-r--r--mlir/test/Dialect/Linalg/decompose-unpack.mlir17
-rw-r--r--mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir56
-rw-r--r--mlir/test/Dialect/Linalg/fold-add-into-dest.mlir30
-rw-r--r--mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir15
-rw-r--r--mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir14
-rw-r--r--mlir/test/Dialect/Linalg/named-ops.mlir44
-rw-r--r--mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir85
-rw-r--r--mlir/test/Dialect/Linalg/roundtrip.mlir49
-rw-r--r--mlir/test/Dialect/Linalg/simplify-depthwise-conv.mlir (renamed from mlir/test/Dialect/Linalg/namedop_conversion.mlir)2
-rw-r--r--mlir/test/Dialect/Linalg/tile-to-forall.mlir2
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir12
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-pad.mlir6
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir89
-rw-r--r--mlir/test/Dialect/Linalg/transpose-matmul.mlir38
-rw-r--r--mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir665
-rw-r--r--mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir124
-rw-r--r--mlir/test/Dialect/Math/expand-math.mlir35
-rw-r--r--mlir/test/Dialect/Math/ops.mlir15
-rw-r--r--mlir/test/Dialect/NVGPU/invalid.mlir11
-rw-r--r--mlir/test/Dialect/OpenACC/ops.mlir64
-rw-r--r--mlir/test/Dialect/OpenMP/invalid.mlir107
-rw-r--r--mlir/test/Dialect/OpenMP/ops.mlir12
-rw-r--r--mlir/test/Dialect/Ptr/invalid.mlir40
-rw-r--r--mlir/test/Dialect/Ptr/ops.mlir47
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir15
-rw-r--r--mlir/test/Dialect/SCF/ops.mlir4
-rw-r--r--mlir/test/Dialect/SPIRV/IR/invalid.mlir43
-rw-r--r--mlir/test/Dialect/Tosa/availability.mlir4
-rw-r--r--mlir/test/Dialect/Tosa/canonicalize.mlir69
-rw-r--r--mlir/test/Dialect/Tosa/dynamic_extension.mlir8
-rw-r--r--mlir/test/Dialect/Tosa/error_if_check.mlir30
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir65
-rw-r--r--mlir/test/Dialect/Tosa/invalid_extension.mlir14
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir8
-rw-r--r--mlir/test/Dialect/Tosa/ops.mlir68
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir4
-rw-r--r--mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir10
-rw-r--r--mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir20
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid.mlir4
-rw-r--r--mlir/test/Dialect/Tosa/verifier.mlir2
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir244
-rw-r--r--mlir/test/Dialect/Vector/invalid.mlir128
-rw-r--r--mlir/test/Dialect/Vector/linearize.mlir14
-rw-r--r--mlir/test/Dialect/Vector/ops.mlir4
-rw-r--r--mlir/test/Dialect/Vector/transform-vector.mlir7
-rw-r--r--mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir45
-rw-r--r--mlir/test/Dialect/Vector/vector-gather-lowering.mlir4
-rw-r--r--mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir2
-rw-r--r--mlir/test/Dialect/Vector/vector-sink.mlir30
-rw-r--r--mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir8
-rw-r--r--mlir/test/Dialect/Vector/vector-warp-distribute.mlir55
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/if.mlir53
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir7
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/table.mlir7
-rw-r--r--mlir/test/Dialect/XeGPU/invalid.mlir196
-rw-r--r--mlir/test/Dialect/XeGPU/layout.mlir23
-rw-r--r--mlir/test/Dialect/XeGPU/ops.mlir164
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir37
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir27
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir85
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir368
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir55
-rw-r--r--mlir/test/Dialect/common_folders.mlir22
-rw-r--r--mlir/test/Examples/standalone/test.toy7
-rw-r--r--mlir/test/IR/test-walk-pattern-rewrite-driver.mlir24
-rw-r--r--mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir9
-rw-r--r--mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir9
-rw-r--r--mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir8
-rw-r--r--mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir11
-rw-r--r--mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir9
-rw-r--r--mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir9
-rw-r--r--mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir12
-rw-r--r--mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir8
-rw-r--r--mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir12
-rw-r--r--mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir11
-rw-r--r--mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir16
-rw-r--r--mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir11
-rw-r--r--mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir11
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f32.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f64.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transfer-write-2d.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction.mlir2
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/scalable-interleave.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/interleave.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/outerproduct-f32.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/outerproduct-i64.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir4
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir2
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/transfer-read.mlir2
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/transfer-write.mlir8
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/lit.local.cfg4
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_dpas.mlir146
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store.mlir109
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_pack_register.mlir131
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_transpose.mlir133
-rw-r--r--mlir/test/Integration/Dialect/XeVM/GPU/xevm_store_cst.mlir75
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir20
-rw-r--r--mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir16
-rw-r--r--mlir/test/Pass/pipeline-options-parsing.mlir10
-rw-r--r--mlir/test/Target/Cpp/class.mlir78
-rw-r--r--mlir/test/Target/Cpp/const.mlir8
-rw-r--r--mlir/test/Target/Cpp/control_flow.mlir2
-rw-r--r--mlir/test/Target/Cpp/expressions.mlir76
-rw-r--r--mlir/test/Target/Cpp/for.mlir6
-rw-r--r--mlir/test/Target/Cpp/member.mlir33
-rw-r--r--mlir/test/Target/Cpp/switch.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/Import/function-attributes.ll17
-rw-r--r--mlir/test/Target/LLVMIR/Import/global-variables.ll35
-rw-r--r--mlir/test/Target/LLVMIR/Import/intrinsic-prefer-unregistered.ll8
-rw-r--r--mlir/test/Target/LLVMIR/Import/intrinsic.ll21
-rw-r--r--mlir/test/Target/LLVMIR/fp-math-function-attributes.mlir18
-rw-r--r--mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir19
-rw-r--r--mlir/test/Target/LLVMIR/llvmir-invalid.mlir20
-rw-r--r--mlir/test/Target/LLVMIR/llvmir.mlir12
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/prefetch.mlir18
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir117
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir56
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/tma_store.mlir94
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/tma_store_invalid.mlir46
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/tma_store_reduce_invalid.mlir25
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir-invalid.mlir276
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir.mlir181
-rw-r--r--mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir42
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-atomic-capture-control-options.mlir44
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-atomic-update-control-options.mlir36
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-debug-147063.mlir45
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-wsloop.mlir4
-rw-r--r--mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir22
-rw-r--r--mlir/test/Target/LLVMIR/ptr.mlir16
-rw-r--r--mlir/test/Target/LLVMIR/rocdl.mlir39
-rw-r--r--mlir/test/Target/LLVMIR/target-to-data-layout-and-target-features.mlir137
-rw-r--r--mlir/test/Target/LLVMIR/target-to-data-layout-invalid.mlir9
-rw-r--r--mlir/test/Target/LLVMIR/target-to-data-layout-no-init.mlir12
-rw-r--r--mlir/test/Target/LLVMIR/target-to-target-features-dlti-query.mlir75
-rw-r--r--mlir/test/Target/SPIRV/arm-tensor-constant.mlir56
-rw-r--r--mlir/test/Target/SPIRV/debug-negative.mlir5
-rw-r--r--mlir/test/Target/SPIRV/debug.mlir69
-rw-r--r--mlir/test/Target/SPIRV/mlir-translate.mlir29
-rw-r--r--mlir/test/Target/SPIRV/module.mlir18
-rw-r--r--mlir/test/Target/Wasm/abs.mlir23
-rw-r--r--mlir/test/Target/Wasm/and.mlir27
-rw-r--r--mlir/test/Target/Wasm/bad_wasm_version.yaml8
-rw-r--r--mlir/test/Target/Wasm/clz.mlir25
-rw-r--r--mlir/test/Target/Wasm/const.mlir37
-rw-r--r--mlir/test/Target/Wasm/copysign.mlir31
-rw-r--r--mlir/test/Target/Wasm/ctz.mlir25
-rw-r--r--mlir/test/Target/Wasm/div.mlir127
-rw-r--r--mlir/test/Target/Wasm/function_export_out_of_scope.yaml13
-rw-r--r--mlir/test/Target/Wasm/global.mlir66
-rw-r--r--mlir/test/Target/Wasm/import.mlir19
-rw-r--r--mlir/test/Target/Wasm/inputs/abs.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/and.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/clz.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/const.yaml.wasm39
-rw-r--r--mlir/test/Target/Wasm/inputs/copysign.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/ctz.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/div.yaml.wasm89
-rw-r--r--mlir/test/Target/Wasm/inputs/global.yaml.wasm63
-rw-r--r--mlir/test/Target/Wasm/inputs/import.yaml.wasm44
-rw-r--r--mlir/test/Target/Wasm/inputs/local.yaml.wasm37
-rw-r--r--mlir/test/Target/Wasm/inputs/max.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm10
-rw-r--r--mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm10
-rw-r--r--mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm8
-rw-r--r--mlir/test/Target/Wasm/inputs/min.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/neg.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/or.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/popcnt.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/rem.yaml.wasm45
-rw-r--r--mlir/test/Target/Wasm/inputs/rotl.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/rotr.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/shl.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/shr_s.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/shr_u.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/sqrt.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/inputs/stats.yaml.wasm38
-rw-r--r--mlir/test/Target/Wasm/inputs/sub.yaml.wasm39
-rw-r--r--mlir/test/Target/Wasm/inputs/table.yaml.wasm23
-rw-r--r--mlir/test/Target/Wasm/inputs/xor.yaml.wasm33
-rw-r--r--mlir/test/Target/Wasm/invalid_function_type_index.yaml16
-rw-r--r--mlir/test/Target/Wasm/local.mlir59
-rw-r--r--mlir/test/Target/Wasm/max.mlir30
-rw-r--r--mlir/test/Target/Wasm/memory_min_eq_max.mlir7
-rw-r--r--mlir/test/Target/Wasm/memory_min_max.mlir7
-rw-r--r--mlir/test/Target/Wasm/memory_min_no_max.mlir7
-rw-r--r--mlir/test/Target/Wasm/min.mlir29
-rw-r--r--mlir/test/Target/Wasm/missing_header.yaml12
-rw-r--r--mlir/test/Target/Wasm/neg.mlir23
-rw-r--r--mlir/test/Target/Wasm/or.mlir27
-rw-r--r--mlir/test/Target/Wasm/popcnt.mlir25
-rw-r--r--mlir/test/Target/Wasm/rem.mlir53
-rw-r--r--mlir/test/Target/Wasm/rotl.mlir27
-rw-r--r--mlir/test/Target/Wasm/rotr.mlir27
-rw-r--r--mlir/test/Target/Wasm/shl.mlir27
-rw-r--r--mlir/test/Target/Wasm/shr_s.mlir27
-rw-r--r--mlir/test/Target/Wasm/shr_u.mlir27
-rw-r--r--mlir/test/Target/Wasm/sqrt.mlir23
-rw-r--r--mlir/test/Target/Wasm/sub.mlir52
-rw-r--r--mlir/test/Target/Wasm/xor.mlir27
-rw-r--r--mlir/test/Transforms/inlining.mlir10
-rw-r--r--mlir/test/Transforms/remove-dead-values.mlir44
-rw-r--r--mlir/test/Transforms/test-canonicalize.mlir16
-rw-r--r--mlir/test/Transforms/test-context-aware-type-converter.mlir40
-rw-r--r--mlir/test/Transforms/test-legalize-type-conversion.mlir1
-rw-r--r--mlir/test/Transforms/test-legalizer-fold-after.mlir9
-rw-r--r--mlir/test/Transforms/test-legalizer-fold-before.mlir9
-rw-r--r--mlir/test/Transforms/test-legalizer-no-fold.mlir12
-rw-r--r--mlir/test/Transforms/test-legalizer.mlir63
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp1
-rw-r--r--mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp9
-rw-r--r--mlir/test/lib/Dialect/GPU/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/LLVM/TestPatterns.cpp34
-rw-r--r--mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp2
-rw-r--r--mlir/test/lib/Dialect/Math/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Math/TestExpandMath.cpp62
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttributes.cpp12
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp7
-rw-r--r--mlir/test/lib/Dialect/Test/TestEnumDefs.td10
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td22
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp218
-rw-r--r--mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp4
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp24
-rw-r--r--mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp111
-rw-r--r--mlir/test/lib/Pass/TestPassManager.cpp59
-rw-r--r--mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp2
-rw-r--r--mlir/test/lit.cfg.py11
-rw-r--r--mlir/test/lit.site.cfg.py.in2
-rw-r--r--mlir/test/mlir-runner/test-expand-math-approx.mlir2
-rw-r--r--mlir/test/mlir-tblgen/enums-python-bindings.td18
-rw-r--r--mlir/test/mlir-translate/emitc_classops.mlir78
-rw-r--r--mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py2
-rw-r--r--mlir/test/python/dialects/nvvm.py41
-rw-r--r--mlir/test/python/dialects/transform_vector_ext.py2
-rw-r--r--mlir/test/python/global_constructors.py72
-rw-r--r--mlir/test/python/ir/auto_location.py101
-rw-r--r--mlir/test/python/ir/context_managers.py22
-rw-r--r--mlir/test/python/ir/module.py24
-rw-r--r--mlir/test/python/ir/operation.py10
-rw-r--r--mlir/test/python/ir/symbol_table.py1
-rw-r--r--mlir/test/python/pass_manager.py33
-rw-r--r--mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp4
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
-rw-r--r--mlir/tools/mlir-runner/mlir-runner.cpp6
-rw-r--r--mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp8
-rw-r--r--mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp6
-rw-r--r--mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/EnumsGen.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp1
-rw-r--r--mlir/tools/mlir-tblgen/mlir-tblgen.cpp11
-rw-r--r--mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp29
-rw-r--r--mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp24
-rw-r--r--mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp11
-rw-r--r--mlir/unittests/Dialect/SparseTensor/MergerTest.cpp2
-rw-r--r--mlir/unittests/ExecutionEngine/Invoke.cpp59
-rw-r--r--mlir/unittests/IR/AttrTypeReplacerTest.cpp8
-rw-r--r--mlir/unittests/IR/CMakeLists.txt2
-rw-r--r--mlir/unittests/IR/RemarkTest.cpp315
-rw-r--r--mlir/unittests/Rewrite/PatternBenefit.cpp7
-rwxr-xr-xmlir/utils/clang-tidy/apply-clang-tidy.sh10
-rw-r--r--mlir/utils/tree-sitter-mlir/dialect/linalg.js2
-rw-r--r--mlir/utils/tree-sitter-mlir/queries/highlights.scm2
856 files changed, 31737 insertions, 7850 deletions
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index a9414eb..1a211f5 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -137,6 +137,14 @@ else()
set(MLIR_ENABLE_ROCM_CONVERSIONS 0)
endif()
+# Build the XeVM conversions and run according tests if the SPIRV backend
+# is available.
+if ("SPIRV" IN_LIST LLVM_TARGETS_TO_BUILD)
+ set(MLIR_ENABLE_XEVM_CONVERSIONS 1)
+else()
+ set(MLIR_ENABLE_XEVM_CONVERSIONS 0)
+endif()
+
set(MLIR_ENABLE_CUDA_RUNNER 0 CACHE BOOL "Enable building the MLIR CUDA runner")
set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "Enable building the MLIR ROCm runner")
set(MLIR_ENABLE_SYCL_RUNNER 0 CACHE BOOL "Enable building the MLIR SYCL runner")
@@ -183,6 +191,10 @@ configure_file(
set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir"
CACHE STRING "nanobind domain for MLIR python bindings.")
+set(MLIR_PYTHON_PACKAGE_PREFIX "mlir"
+ CACHE STRING "Specifies that all MLIR packages are co-located under the
+ `MLIR_PYTHON_PACKAGE_PREFIX` top level package (the API has been
+ embedded in a relocatable way).")
set(MLIR_ENABLE_BINDINGS_PYTHON 0 CACHE BOOL
"Enables building of Python bindings.")
set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/mlir_core/mlir" CACHE STRING
diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md
index c31b5c6..02e93eb 100644
--- a/mlir/Maintainers.md
+++ b/mlir/Maintainers.md
@@ -27,6 +27,42 @@ dialects, build system and language bindings.
[@joker-eph](https://github.com/joker-eph) (GitHub),
mehdi_amini (Discourse)
+### Code
+
+#### Standalone subcategories
+* Core tooling (ODS, DRR, PDLL, LSP) (core)
+* CMake ([christopherbate](https://github.com/christopherbate))
+* Dialect Conversion ([matthias-springer](https://github.com/matthias-springer), [zero9178](https://github.com/zero9178))
+* Python Bindings ([makslevental](https://github.com/makslevental), [rolfmorel](https://github.com/rolfmorel))
+
+### Dialects
+
+#### Code Structure Dialects
+* Builtin Dialect (core)
+* ‘func’ Dialect (core)
+* ‘scf’ Dialect (core)
+* ‘cf’ Dialect (core)
+* ‘index’ Dialect (core)
+* ‘ptr’ Dialect ([fabianmcg](https://github.com/fabianmcg))
+
+#### Basic Compute Dialects
+* ‘arith’ Dialect (core)
+* ‘math’ Dialect (core)
+* Rewrite System Dialects (core)
+* Transform Dialect ([martin-luecke](https://github.com/martin-luecke), [ftynse](https://github.com/ftynse), [rolfmorel](https://github.com/rolfmorel))
+* ‘pdl_interp’ Dialect ([jpienaar](https://github.com/jpienaar))
+* ‘pdl’ Dialect ([jpienaar](https://github.com/jpienaar))
+
+#### Accessory Dialects
+* ‘affine’ Dialect ([ftynse](https://github.com/ftynse))
+* ‘dlti’ Dialect ([rolfmorel](https://github.com/rolfmorel))
+* ‘irdl’ Dialect ([math-fehr](https://github.com/math-fehr), [moxinilian](https://github.com/moxinilian))
+* ‘shape’ Dialect ([jpienaar](https://github.com/jpienaar))
+* ‘smt’ Dialect ([fabianschuiki](https://github.com/fabianschuiki), [maerhart](https://github.com/maerhart))
+* ‘ub’ Dialect ([Hardcode84](https://github.com/Hardcode84))
+* ‘complex’ Dialect (core)
+* ‘async’ Dialect (unmaintained)
+
## Egress
MLIR components pertaining to egress flows from MLIR, in particular to LLVM IR.
@@ -44,6 +80,40 @@ MLIR components pertaining to egress flows from MLIR, in particular to LLVM IR.
[@gysit](https://github.com/gysit) (GitHub),
gysit (Discourse)
+### Dialects
+
+The `egress` maintainer refers to the people working in the Egress category,
+with the point-of-contact being the maintainers above. Named maintainers, if
+available, should be contacted first, as they're more active in those areas.
+
+#### Lowering Dialects
+* ‘llvm’ Dialect (egress)
+* ‘SPIR-V’ Dialect ([@kuhar](https://github.com/kuhar), [@antiagainst](https://github.com/antiagainst))
+* ‘emitc’ Dialect ([@aniragil](https://github.com/aniragil), [@marbre](https://github.com/marbre))
+
+#### GPU Dialects
+* ‘gpu’ Dialect ([@fabianmcg](https://github.com/fabianmcg))
+* ‘amdgpu’ Dialect ([@krzysz00](https://github.com/krzysz00))
+* ‘rocdl’ Dialect ([@krzysz00](https://github.com/krzysz00))
+* ‘nvgpu’ Dialect ([@grypp](https://github.com/grypp))
+* ‘nvvm’ Dialect ([@grypp](https://github.com/grypp))
+* ‘xegpu’ Dialect ([@chencha3](https://github.com/chencha3), [@Jianhui-Li](https://github.com/Jianhui-Li))
+* 'xevm' Dialect ([@silee2](https://github.com/silee2))
+
+#### CPU Dialects
+* ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space))
+* ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
+* ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
+* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
+* ‘x86vector’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
+* ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
+
+#### Paradigm Dialects
+* ‘omp’ Dialect ([@tblah](https://github.com/tblah), [@skatrak](https://github.com/skatrak))
+* ‘acc’ Dialect ([@clementval](https://github.com/clementval), [@razvanlupusoru](https://github.com/razvanlupusoru))
+* ‘mpi’ Dialect ([@fschlimb](https://github.com/fschlimb))
+* ‘shard’ Dialect ([@fschlimb](https://github.com/fschlimb))
+
## Tensor Compiler
MLIR components specific to construction of compilers for tensor algebra, in
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269e..6589458a 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -174,8 +174,7 @@ function(add_mlir_dialect dialect dialect_namespace)
mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs -typedefs-dialect=${dialect_namespace})
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
mlir_tablegen(${dialect}Dialect.cpp.inc -gen-dialect-defs -dialect=${dialect_namespace})
- add_public_tablegen_target(MLIR${dialect}IncGen)
- add_dependencies(mlir-headers MLIR${dialect}IncGen)
+ add_mlir_dialect_tablegen_target(MLIR${dialect}IncGen)
endfunction()
# Declare sharded dialect operation declarations and definitions
@@ -190,7 +189,7 @@ function(add_sharded_ops ops_target shard_count)
tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
endforeach()
- add_public_tablegen_target(MLIR${ops_target}ShardGen)
+ add_mlir_dialect_tablegen_target(MLIR${ops_target}ShardGen)
set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
endfunction()
@@ -199,10 +198,23 @@ function(add_mlir_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-op-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-op-interface-defs)
- add_public_tablegen_target(MLIR${interface}IncGen)
- add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
+ add_mlir_generic_tablegen_target(MLIR${interface}IncGen)
endfunction()
+# Add a dialect-specific tablegen target that generates headers in the include directory.
+# In most cases, this is what should be used after invoking `mlir_tablegen`.
+macro(add_mlir_dialect_tablegen_target target)
+ add_public_tablegen_target(${target})
+ add_dependencies(mlir-headers ${target})
+endmacro()
+
+# Add a dialect-independent tablegen target that generates headers in the include directory.
+# Generally this is used for files outside of the Dialects/ folder, and also for interfaces
+# that do not depend on dialect-specific headers.
+macro(add_mlir_generic_tablegen_target target)
+ add_public_tablegen_target(${target})
+ add_dependencies(mlir-generic-headers ${target})
+endmacro()
# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
@@ -388,6 +400,9 @@ function(add_mlir_library name)
if(TARGET ${name})
target_link_libraries(${name} INTERFACE ${LLVM_COMMON_LIBS})
+ if(ARG_INSTALL_WITH_TOOLCHAIN)
+ set_target_properties(${name} PROPERTIES MLIR_INSTALL_WITH_TOOLCHAIN TRUE)
+ endif()
if(NOT ARG_DISABLE_INSTALL)
add_mlir_library_install(${name})
endif()
@@ -617,26 +632,27 @@ endfunction(add_mlir_aggregate)
# This is usually done as part of add_mlir_library but is broken out for cases
# where non-standard library builds can be installed.
function(add_mlir_library_install name)
- if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY)
- get_target_export_arg(${name} MLIR export_to_mlirtargets UMBRELLA mlir-libraries)
- install(TARGETS ${name}
- COMPONENT ${name}
- ${export_to_mlirtargets}
- LIBRARY DESTINATION lib${LLVM_LIBDIR_SUFFIX}
- ARCHIVE DESTINATION lib${LLVM_LIBDIR_SUFFIX}
- RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}"
- # Note that CMake will create a directory like:
- # objects-${CMAKE_BUILD_TYPE}/obj.LibName
- # and put object files there.
- OBJECTS DESTINATION lib${LLVM_LIBDIR_SUFFIX}
- )
+ get_target_property(_install_with_toolchain ${name} MLIR_INSTALL_WITH_TOOLCHAIN)
+ if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY OR _install_with_toolchain)
+ get_target_export_arg(${name} MLIR export_to_mlirtargets UMBRELLA mlir-libraries)
+ install(TARGETS ${name}
+ COMPONENT ${name}
+ ${export_to_mlirtargets}
+ LIBRARY DESTINATION lib${LLVM_LIBDIR_SUFFIX}
+ ARCHIVE DESTINATION lib${LLVM_LIBDIR_SUFFIX}
+ RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}"
+ # Note that CMake will create a directory like:
+ # objects-${CMAKE_BUILD_TYPE}/obj.LibName
+ # and put object files there.
+ OBJECTS DESTINATION lib${LLVM_LIBDIR_SUFFIX}
+ )
- if (NOT LLVM_ENABLE_IDE)
- add_llvm_install_targets(install-${name}
- DEPENDS ${name}
- COMPONENT ${name})
- endif()
- set_property(GLOBAL APPEND PROPERTY MLIR_ALL_LIBS ${name})
+ if (NOT LLVM_ENABLE_IDE)
+ add_llvm_install_targets(install-${name}
+ DEPENDS ${name}
+ COMPONENT ${name})
+ endif()
+ set_property(GLOBAL APPEND PROPERTY MLIR_ALL_LIBS ${name})
endif()
set_property(GLOBAL APPEND PROPERTY MLIR_EXPORTS ${name})
endfunction()
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index c14e614..2b88355 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -704,7 +704,12 @@ function(add_mlir_python_extension libname extname)
# NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
# doesn't declare this API as undefined in its linker flags. So we need to declare it as such
# for downstream users that do not do something like `-undefined dynamic_lookup`.
- target_link_options(${libname} PUBLIC "LINKER:-U,_PyClassMethod_New")
+ # Same for the rest.
+ target_link_options(${libname} PUBLIC
+ "LINKER:-U,_PyClassMethod_New"
+ "LINKER:-U,_PyCode_Addr2Location"
+ "LINKER:-U,_PyFrame_GetLasti"
+ )
endif()
endif()
diff --git a/mlir/cmake/modules/FindSyclRuntime.cmake b/mlir/cmake/modules/FindSyclRuntime.cmake
index 9e6ae04..5986895 100644
--- a/mlir/cmake/modules/FindSyclRuntime.cmake
+++ b/mlir/cmake/modules/FindSyclRuntime.cmake
@@ -19,7 +19,7 @@ if(NOT DEFINED ENV{CMPLR_ROOT})
else()
get_filename_component(ONEAPI_VER "$ENV{CMPLR_ROOT}" NAME)
if(ONEAPI_VER VERSION_LESS 2024.0)
- if(LINUX OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux"))
+ if(LINUX OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Linux"))
set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/linux")
elseif(WIN32)
set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/windows")
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bef9e7f..98ac635 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -216,13 +216,28 @@ added to an attached operation, they need to be re-parented to the containing
module).
Due to the validity and parenting accounting needs, `PyOperation` is the owner
-for regions and blocks and needs to be a top-level type that we can count on not
-aliasing. This let's us do things like selectively invalidating instances when
-mutations occur without worrying that there is some alias to the same operation
-in the hierarchy. Operations are also the only entity that are allowed to be in
-a detached state, and they are interned at the context level so that there is
-never more than one Python `mlir.ir.Operation` object for a unique
-`MlirOperation`, regardless of how it is obtained.
+for regions and blocks. Operations are also the only entities which are allowed to be in
+a detached state.
+
+**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`.
+This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op`
+and you somehow transform `op` (e.g., you run a pass on `op`) then walking the MLIR AST via either/or `py_op1`, `py_op2`
+will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any
+operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**.
+For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is
+transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2`
+become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to
+`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the
+purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that
+
+1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.;
+2. Second, transform the AST/erase operations/etc. via a single root object;
+3. Invalidate all queried nodes (e.g., using `op._set_invalid()`).
+
+Ideally this should be done in a function body so that step (3) corresponds to the end of the function and there are no
+risks of Python wrapper objects leaking/living longer than necessary. In summary, you should scope your changes based on
+nesting i.e., change leaf nodes first before going up in hierarchy, and only in very rare cases query nested ops post
+modifying a parent op.
The C/C++ API allows for Region/Block to also be detached, but it simplifies the
ownership model a lot to eliminate that possibility in this API, allowing the
@@ -238,11 +253,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on
how hard it is to guarantee how mutations interact with their Python peer
objects. We can cross that bridge easily when we get there.
-Module, when used purely from the Python API, can't alias anyway, so we can use
-it as a top-level ref type without a live-list for interning. If the API ever
-changes such that this cannot be guaranteed (i.e. by letting you marshal a
-native-defined Module in), then there would need to be a live table for it too.
-
## User-level API
### Context Management
@@ -1229,4 +1239,4 @@ The exceptions to the free-threading compatibility:
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
-- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe. \ No newline at end of file
+- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 556e73c2..7070351 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -280,6 +280,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.
+There are two kind of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -332,29 +341,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
- /// Register a conversion function. A conversion function defines how a given
- /// source type should be converted. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
- /// * Optional<Type>(T)
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
+ ///
+ /// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
- /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
- /// converter is allowed to try another conversion function to perform
- /// the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+ /// the converter is allowed to try another conversion function to
+ /// perform the conversion.
+ /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
- /// `failure` or `std::nullopt` to signify a failed conversion. If the new
- /// set of types is empty, the type is removed and any usages of the
+ /// `failure` or `std::nullopt` to signify a failed conversion. If the
+ /// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
diff --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md
index 94b053d..8d4d2ca 100644
--- a/mlir/docs/Dialects/GPU.md
+++ b/mlir/docs/Dialects/GPU.md
@@ -193,10 +193,25 @@ llvm.func @foo() {
// mlir-translate --mlir-to-llvmir:
@binary_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
@binary_func_kernel_name = private unnamed_addr constant [7 x i8] c"func\00", align 1
+@binary_module = internal global ptr null
+@llvm.global_ctors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_load, ptr null}]
+@llvm.global_dtors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_unload, ptr null}]
+define internal void @binary_load() section ".text.startup" {
+entry:
+ %0 = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
+ store ptr %0, ptr @binary_module
+ ...
+}
+define internal void @binary_unload() section ".text.startup" {
+entry:
+ %0 = load ptr, ptr @binary_module, align 8
+ call void @mgpuModuleUnload(ptr %0)
+ ...
+}
...
define void @foo() {
...
- %module = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
+ %module = load ptr, ptr @binary_module, align 8
%kernel = call ptr @mgpuModuleGetFunction(ptr %module, ptr @binary_func_kernel_name)
call void @mgpuLaunchKernel(ptr %kernel, ...) ; Launch the kernel
...
diff --git a/mlir/docs/Remarks.md b/mlir/docs/Remarks.md
new file mode 100644
index 0000000..ee2ea85
--- /dev/null
+++ b/mlir/docs/Remarks.md
@@ -0,0 +1,259 @@
+# Remark Infrastructure
+
+Remarks are **structured, human- and machine-readable notes** emitted by the
+compiler to explain:
+
+- What was transformed
+- What was missed
+- Why it happened
+
+The **`RemarkEngine`** collects finalized remarks during compilation and sends
+them to a pluggable **streamer**. By default, MLIR integrates with LLVM’s
+[`llvm::remarks`](https://llvm.org/docs/Remarks.html), allowing you to:
+
+- Stream remarks as passes run
+- Serialize them to **YAML** or **LLVM bitstream** for tooling
+
+***
+
+## Key Points
+
+- **Opt-in** – Disabled by default; zero overhead unless enabled.
+- **Per-context** – Configured on `MLIRContext`.
+- **Formats** – LLVM Remark engine (YAML / Bitstream) or custom streamers.
+- **Kinds** – `Passed`, `Missed`, `Failure`, `Analysis`.
+- **API** – Lightweight streaming interface using `<<` (like MLIR diagnostics).
+
+***
+
+## How It Works
+
+Two main components:
+
+- **`RemarkEngine`** (owned by `MLIRContext`): Receives finalized
+ `InFlightRemark`s, optionally mirrors them to the `DiagnosticEngine`, and
+ dispatches to the installed streamer.
+
+- **`MLIRRemarkStreamerBase`** (abstract): Backend interface with a single hook:
+
+ ```c++
+ virtual void streamOptimizationRemark(const Remark &remark) = 0;
+ ```
+
+**Default backend – `MLIRLLVMRemarkStreamer`** Adapts `mlir::Remark` to LLVM’s
+remark format and writes YAML/bitstream via `llvm::remarks::RemarkStreamer`.
+
+**Ownership flow:** `MLIRContext` → `RemarkEngine` → `MLIRRemarkStreamerBase`
+
+***
+
+## Categories
+
+MLIR provides four built-in remark categories (extendable if needed):
+
+#### 1. **Passed**
+
+Optimization/transformation succeeded.
+
+```
+[Passed] RemarkName | Category:Vectorizer:myPass1 | Function=foo | Remark="vectorized loop", tripCount=128
+```
+
+#### 2. **Missed**
+
+Optimization/transformation didn’t apply — ideally with actionable feedback.
+
+```
+[Missed] | Category:Unroll | Function=foo | Reason="tripCount=4 < threshold=256", Suggestion="increase unroll to 128"
+```
+
+#### 3. **Failure**
+
+Optimization/transformation attempted but failed. This is slightly different
+from the `Missed` category.
+
+For example, the user specifies `-use-max-register=100` when invoking the
+compiler, but the attempt fails for some reason:
+
+```bash
+$ your-compiler -use-max-register=100 mycode.xyz
+```
+
+```
+[Failed] Category:RegisterAllocator | Reason="Limiting to use-max-register=100 failed; it now uses 104 registers for better performance"
+```
+
+#### 4. **Analysis**
+
+Neutral analysis results.
+
+```
+[Analysis] Category:Register | Remark="Kernel uses 168 registers"
+[Analysis] Category:Register | Remark="Kernel uses 10kB local memory"
+```
+
+***
+
+## Emitting Remarks
+
+The `remark::*` helpers return an **in-flight remark**.
+You append strings or key–value metrics using `<<`.
+
+### Remark Options
+
+When constructing a remark, you typically provide four fields that are `StringRef`:
+
+1. **Remark name** – identifiable name
+2. **Category** – high-level classification
+3. **Sub-category** – more fine-grained classification
+4. **Function name** – the function where the remark originates
+
+
+### Example
+
+```c++
+#include "mlir/IR/Remarks.h"
+
+LogicalResult MyPass::runOnOperation() {
+ Location loc = getOperation()->getLoc();
+
+ remark::RemarkOpts opts = remark::RemarkOpts::name(MyRemarkName1)
+ .category(categoryVectorizer)
+ .function(fName)
+ .subCategory(myPassname1);
+
+ // PASSED
+ remark::passed(loc, opts)
+ << "vectorized loop"
+ << remark::metric("tripCount", 128);
+
+ // ANALYSIS
+ remark::analysis(loc, opts)
+ << "Kernel uses 168 registers";
+
+ // MISSED (with reason + suggestion)
+ int tripBad = 4, threshold = 256, target = 128;
+ remark::missed(loc, opts)
+ << remark::reason("tripCount={0} < threshold={1}", tripBad, threshold)
+ << remark::suggest("increase unroll to {0}", target);
+
+ // FAILURE
+ remark::failed(loc, opts)
+ << remark::reason("failed due to unsupported pattern");
+
+ return success();
+}
+```
+
+***
+
+### Metrics and Shortcuts
+
+Helper functions accept
+[LLVM format](https://llvm.org/docs/ProgrammersManual.html#formatting-strings-the-formatv-function)
+style strings. This format builds lazily, so remarks are zero-cost when
+disabled.
+
+#### Adding Remarks
+
+- **`remark::add(fmt, ...)`** – Shortcut for `metric("Remark", ...)`.
+
+#### Adding Reasons
+
+- **`remark::reason(fmt, ...)`** – Shortcut for `metric("Reason", ...)`. Used to
+ explain why a remark was missed or failed.
+
+#### Adding Suggestions
+
+- **`remark::suggest(fmt, ...)`** – Shortcut for `metric("Suggestion", ...)`.
+ Used to provide actionable feedback.
+
+#### Adding Custom Metrics
+
+- **`remark::metric(key, value)`** – Adds a structured key–value metric.
+
+Example: tracking `TripCount`. When exported to YAML, it appears under `args`
+for machine readability:
+
+```cpp
+remark::metric("TripCount", value)
+```
+
+#### String Metrics
+
+Passing a plain string (e.g. `<< "vectorized loop"`) is equivalent to:
+
+```cpp
+metric("Remark", "vectorized loop")
+```
+
+***
+
+## Enabling Remarks
+
+### 1. **With LLVMRemarkStreamer (YAML or Bitstream)**
+
+Persists remarks to a file in the chosen format.
+
+```c++
+mlir::remark::RemarkCategories cats{/*passed=*/categoryLoopunroll,
+ /*missed=*/std::nullopt,
+ /*analysis=*/std::nullopt,
+ /*failed=*/categoryLoopunroll};
+
+mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
+ context, yamlFile, llvm::remarks::Format::YAML, cats);
+```
+
+**YAML format** – human-readable, easy to diff:
+
+```yaml
+--- !Passed
+pass: Category:SubCategory
+name: MyRemarkName1
+function: myFunc
+loc: myfile.mlir:12:3
+args:
+ - Remark: vectorized loop
+ - tripCount: 128
+```
+
+**Bitstream format** – compact binary for large runs.
+
+***
+
+### 2. **With `mlir::emitRemarks` (No Streamer)**
+
+If the streamer isn't passed, the remarks are mirrored to the `DiagnosticEngine`
+using `mlir::emitRemarks`
+
+```c++
+mlir::remark::RemarkCategories cats{/*passed=*/categoryLoopunroll,
+ /*missed=*/std::nullopt,
+ /*analysis=*/std::nullopt,
+ /*failed=*/categoryLoopunroll};
+remark::enableOptimizationRemarks(
+ /*streamer=*/nullptr, cats,
+ /*printAsEmitRemarks=*/true);
+```
+
+***
+
+### 3. **With a Custom Streamer**
+
+You can implement a custom streamer by inheriting `MLIRRemarkStreamerBase` to
+consume remarks in any format.
+
+```c++
+class MyStreamer : public MLIRRemarkStreamerBase {
+public:
+ void streamOptimizationRemark(const Remark &remark) override {
+ // Convert and write remark to your custom format
+ }
+};
+
+auto myStreamer = std::make_unique<MyStreamer>();
+remark::enableOptimizationRemarks(
+ /*streamer=*/myStreamer, cats,
+ /*printAsEmitRemarks=*/true);
+```
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index e9abe36..621f6a6 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -170,7 +170,7 @@ let arguments = (ins
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);
-
+```
We have already provided the definition in the `extraClassDeclaration`
field of the `FuncOp` class:
diff --git a/mlir/examples/standalone/CMakeLists.txt b/mlir/examples/standalone/CMakeLists.txt
index 038242b..88dfa3e 100644
--- a/mlir/examples/standalone/CMakeLists.txt
+++ b/mlir/examples/standalone/CMakeLists.txt
@@ -8,6 +8,10 @@ set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to")
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
find_package(MLIR REQUIRED CONFIG)
+
+ # Define the default argument to use by `lit` when testing.
+ set(LLVM_LIT_ARGS "-sv" CACHE STRING "Default options for lit")
+
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
@@ -48,6 +52,10 @@ add_subdirectory(include)
add_subdirectory(lib)
if(MLIR_ENABLE_BINDINGS_PYTHON)
message(STATUS "Enabling Python API")
+ include(MLIRDetectPythonEnv)
+ mlir_configure_python_dev_packages()
+ set(MLIR_PYTHON_PACKAGE_PREFIX "mlir_standalone" CACHE STRING "" FORCE)
+ set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/standalone/mlir_standalone" CACHE STRING "" FORCE)
add_subdirectory(python)
endif()
add_subdirectory(test)
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index 69c82fd..a0eca9c 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -2,8 +2,7 @@ include(AddMLIRPython)
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
# top level package (the API has been embedded in a relocatable way).
-# TODO: Add an upstream cmake param for this vs having a global here.
-add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone.")
+add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
################################################################################
@@ -49,8 +48,8 @@ declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
add_mlir_python_common_capi_library(StandalonePythonCAPI
INSTALL_COMPONENT StandalonePythonModules
- INSTALL_DESTINATION python_packages/standalone/mlir_standalone/_mlir_libs
- OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/standalone/mlir_standalone/_mlir_libs"
+ INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
RELATIVE_INSTALL_ROOT "../../../.."
DECLARED_SOURCES
StandalonePythonSources
@@ -58,6 +57,7 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI
# available.
MLIRPythonExtension.RegisterEverything
MLIRPythonSources.Core
+ MLIRPythonSources.Dialects.builtin
)
################################################################################
@@ -65,14 +65,15 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI
################################################################################
add_mlir_python_modules(StandalonePythonModules
- ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/standalone/mlir_standalone"
- INSTALL_PREFIX "python_packages/standalone/mlir_standalone"
+ ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
+ INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
DECLARED_SOURCES
StandalonePythonSources
# TODO: Remove this in favor of showing fine grained registration once
# available.
MLIRPythonExtension.RegisterEverything
- MLIRPythonSources
+ MLIRPythonSources.Core
+ MLIRPythonSources.Dialects.builtin
COMMON_CAPI_LINK_LIBS
StandalonePythonCAPI
)
diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp
index 2546f2a..8416424 100644
--- a/mlir/examples/toy/Ch1/parser/AST.cpp
+++ b/mlir/examples/toy/Ch1/parser/AST.cpp
@@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp
index fb7b484..b9f3a2d 100644
--- a/mlir/examples/toy/Ch1/toyc.cpp
+++ b/mlir/examples/toy/Ch1/toyc.cpp
@@ -39,7 +39,8 @@ static cl::opt<enum Action>
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp
index 2546f2a..8416424 100644
--- a/mlir/examples/toy/Ch2/parser/AST.cpp
+++ b/mlir/examples/toy/Ch2/parser/AST.cpp
@@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp
index e33b49b..a60738d 100644
--- a/mlir/examples/toy/Ch2/toyc.cpp
+++ b/mlir/examples/toy/Ch2/toyc.cpp
@@ -58,7 +58,8 @@ static cl::opt<enum Action> emitAction(
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
@@ -71,7 +72,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
-int dumpMLIR() {
+static int dumpMLIR() {
mlir::MLIRContext context;
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
@@ -112,7 +113,7 @@ int dumpMLIR() {
return 0;
}
-int dumpAST() {
+static int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
return 5;
diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp
index 2546f2a..8416424 100644
--- a/mlir/examples/toy/Ch3/parser/AST.cpp
+++ b/mlir/examples/toy/Ch3/parser/AST.cpp
@@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp
index f8aa846..3094935 100644
--- a/mlir/examples/toy/Ch3/toyc.cpp
+++ b/mlir/examples/toy/Ch3/toyc.cpp
@@ -64,7 +64,8 @@ static cl::opt<enum Action> emitAction(
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
@@ -77,8 +78,8 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
-int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).ends_with(".mlir")) {
@@ -107,7 +108,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
return 0;
}
-int dumpMLIR() {
+static int dumpMLIR() {
mlir::MLIRContext context;
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
@@ -134,7 +135,7 @@ int dumpMLIR() {
return 0;
}
-int dumpAST() {
+static int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
return 5;
diff --git a/mlir/examples/toy/Ch4/parser/AST.cpp b/mlir/examples/toy/Ch4/parser/AST.cpp
index 2546f2a..8416424 100644
--- a/mlir/examples/toy/Ch4/parser/AST.cpp
+++ b/mlir/examples/toy/Ch4/parser/AST.cpp
@@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index ae02bc4..36816f0 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -65,7 +65,8 @@ static cl::opt<enum Action> emitAction(
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
@@ -78,8 +79,8 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
-int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).ends_with(".mlir")) {
@@ -108,7 +109,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
return 0;
}
-int dumpMLIR() {
+static int dumpMLIR() {
mlir::MLIRContext context;
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
@@ -143,7 +144,7 @@ int dumpMLIR() {
return 0;
}
-int dumpAST() {
+static int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
return 5;
diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp
index 2546f2a..8416424 100644
--- a/mlir/examples/toy/Ch5/parser/AST.cpp
+++ b/mlir/examples/toy/Ch5/parser/AST.cpp
@@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
index afdf782..3760a88 100644
--- a/mlir/examples/toy/Ch5/toyc.cpp
+++ b/mlir/examples/toy/Ch5/toyc.cpp
@@ -71,7 +71,8 @@ static cl::opt<enum Action> emitAction(
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
@@ -84,8 +85,8 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
-int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).ends_with(".mlir")) {
@@ -114,7 +115,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
return 0;
}
-int dumpMLIR() {
+static int dumpMLIR() {
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
@@ -171,7 +172,7 @@ int dumpMLIR() {
return 0;
}
-int dumpAST() {
+static int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
return 5;
diff --git a/mlir/examples/toy/Ch6/parser/AST.cpp b/mlir/examples/toy/Ch6/parser/AST.cpp
index 2546f2a..8416424 100644
--- a/mlir/examples/toy/Ch6/parser/AST.cpp
+++ b/mlir/examples/toy/Ch6/parser/AST.cpp
@@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index 4a5e109..c31c53a 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -96,7 +96,8 @@ static cl::opt<enum Action> emitAction(
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
@@ -109,8 +110,8 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
-int loadMLIR(mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadMLIR(mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).ends_with(".mlir")) {
@@ -140,8 +141,8 @@ int loadMLIR(mlir::MLIRContext &context,
return 0;
}
-int loadAndProcessMLIR(mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadAndProcessMLIR(mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
if (int error = loadMLIR(context, module))
return error;
@@ -196,7 +197,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
return 0;
}
-int dumpAST() {
+static int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
return 5;
@@ -210,7 +211,7 @@ int dumpAST() {
return 0;
}
-int dumpLLVMIR(mlir::ModuleOp module) {
+static int dumpLLVMIR(mlir::ModuleOp module) {
// Register the translation to LLVM IR with the MLIR context.
mlir::registerBuiltinDialectTranslation(*module->getContext());
mlir::registerLLVMDialectTranslation(*module->getContext());
@@ -254,7 +255,7 @@ int dumpLLVMIR(mlir::ModuleOp module) {
return 0;
}
-int runJit(mlir::ModuleOp module) {
+static int runJit(mlir::ModuleOp module) {
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
diff --git a/mlir/examples/toy/Ch7/parser/AST.cpp b/mlir/examples/toy/Ch7/parser/AST.cpp
index e38a743a..aa2c784 100644
--- a/mlir/examples/toy/Ch7/parser/AST.cpp
+++ b/mlir/examples/toy/Ch7/parser/AST.cpp
@@ -123,7 +123,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
-void printLitHelper(ExprAST *litOrNum) {
+static void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue();
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index 32208ecca..553ca6b 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -96,7 +96,8 @@ static cl::opt<enum Action> emitAction(
static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
-std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+static std::unique_ptr<toy::ModuleAST>
+parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
@@ -109,8 +110,8 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.parseModule();
}
-int loadMLIR(mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadMLIR(mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).ends_with(".mlir")) {
@@ -140,8 +141,8 @@ int loadMLIR(mlir::MLIRContext &context,
return 0;
}
-int loadAndProcessMLIR(mlir::MLIRContext &context,
- mlir::OwningOpRef<mlir::ModuleOp> &module) {
+static int loadAndProcessMLIR(mlir::MLIRContext &context,
+ mlir::OwningOpRef<mlir::ModuleOp> &module) {
if (int error = loadMLIR(context, module))
return error;
@@ -197,7 +198,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
return 0;
}
-int dumpAST() {
+static int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
return 5;
@@ -211,7 +212,7 @@ int dumpAST() {
return 0;
}
-int dumpLLVMIR(mlir::ModuleOp module) {
+static int dumpLLVMIR(mlir::ModuleOp module) {
// Register the translation to LLVM IR with the MLIR context.
mlir::registerBuiltinDialectTranslation(*module->getContext());
mlir::registerLLVMDialectTranslation(*module->getContext());
@@ -255,7 +256,7 @@ int dumpLLVMIR(mlir::ModuleOp module) {
return 0;
}
-int runJit(mlir::ModuleOp module) {
+static int runJit(mlir::ModuleOp module) {
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
index 99cddc5..1a58d68 100644
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths, bool enableObjectDump);
+/// Initialize the ExecutionEngine. Global constructors specified by
+/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
+/// binary compiled from `gpu.module` gets loaded during initialization. Make
+/// sure all symbols are resolvable before initialization by calling
+/// `mlirExecutionEngineRegisterSymbol` or including shared libraries.
+MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);
+
/// Destroy an ExecutionEngine instance.
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 71c7d43..e973697 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -415,6 +415,9 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
/// The returned module is null when the input operation was not a ModuleOp.
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
+/// Checks if two modules are equal.
+MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
+
//===----------------------------------------------------------------------===//
// Operation state.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 4b18024..f865357 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -479,10 +479,28 @@ public:
/// respect to a positive constant `divisor`. Two constraints are added to the
/// system to capture equivalence with the floordiv:
/// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1.
- void addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
- const DynamicAPInt &divisor);
- void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor) {
- addLocalFloorDiv(getDynamicAPIntVec(dividend), DynamicAPInt(divisor));
+ /// Returns the column position of the new local variable.
+ unsigned addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
+ const DynamicAPInt &divisor);
+ unsigned addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor) {
+ return addLocalFloorDiv(getDynamicAPIntVec(dividend),
+ DynamicAPInt(divisor));
+ }
+
+ /// Adds a new local variable as the modulus of an affine function of other
+ /// variables, the coefficients of which are provided in `exprs`. The modulus
+ /// is with respect to a positive constant `modulus`. The function returns the
+ /// absolute index of the new local variable representing the result of the
+ /// modulus operation. Two new local variables are added to the system, one
+ /// representing the floor div with respect to the modulus and one
+ /// representing the mod. Three constraints are added to the system to capture
+ /// the equivalance. The first two are required to compute the result of the
+ /// floor division `q`, and the third computes the equality relation:
+ /// result = exprs - modulus * q.
+ unsigned addLocalModulo(ArrayRef<DynamicAPInt> exprs,
+ const DynamicAPInt &modulus);
+ unsigned addLocalModulo(ArrayRef<int64_t> exprs, int64_t modulus) {
+ return addLocalModulo(getDynamicAPIntVec(exprs), DynamicAPInt(modulus));
}
/// Projects out (aka eliminates) `num` variables starting at position
@@ -905,6 +923,11 @@ protected:
IntMatrix inequalities;
};
+inline raw_ostream &operator<<(raw_ostream &os, const IntegerRelation &rel) {
+ rel.print(os);
+ return os;
+}
+
/// An IntegerPolyhedron represents the set of points from a PresburgerSpace
/// that satisfy a list of affine constraints. Affine constraints can be
/// inequalities or equalities in the form:
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 054eb7b..15069fa 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -322,7 +322,7 @@ public:
// The parameter is what [the original
// paper](https://www.cs.cmu.edu/~avrim/451f11/lectures/lect1129_LLL.pdf)
// calls `y`, usually 3/4.
- void LLL(Fraction delta);
+ void LLL(const Fraction &delta);
// Multiply each row of the matrix by the LCM of the denominators, thereby
// converting it to an integer matrix.
diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt
index 9cf3b44..f88a35b 100644
--- a/mlir/include/mlir/CMakeLists.txt
+++ b/mlir/include/mlir/CMakeLists.txt
@@ -4,4 +4,5 @@ add_subdirectory(Dialect)
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Reducer)
+add_subdirectory(Target)
add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt
index 9f76ab6..602c8ff 100644
--- a/mlir/include/mlir/Conversion/CMakeLists.txt
+++ b/mlir/include/mlir/Conversion/CMakeLists.txt
@@ -3,7 +3,7 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Conversion)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
-add_public_tablegen_target(MLIRConversionPassIncGen)
+add_mlir_generic_tablegen_target(MLIRConversionPassIncGen)
add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
index 54d7a03..114b60d 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt
@@ -3,5 +3,4 @@ mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
-add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen)
+add_mlir_generic_tablegen_target(MLIRConvertToLLVMInterfaceIncGen)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index d5055f0..8e86808 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -189,15 +189,13 @@ public:
/// `unpack`.
static unsigned getNumUnpackedValues() { return 2; }
- /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
- /// and appends the corresponding values into `sizes`. `addressSpaces`
- /// which must have the same length as `values`, is needed to handle layouts
- /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
- static void computeSizes(OpBuilder &builder, Location loc,
+ /// Builds and returns IR computing the size in bytes (suitable for opaque
+ /// allocation). `addressSpace` is needed to handle layouts where
+ /// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
+ static Value computeSize(OpBuilder &builder, Location loc,
const LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values,
- ArrayRef<unsigned> addressSpaces,
- SmallVectorImpl<Value> &sizes);
+ UnrankedMemRefDescriptor desc,
+ unsigned addressSpace);
/// TODO: The following accessors don't take alignment rules between elements
/// of the descriptor struct into account. For some architectures, it might be
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 969154a..c292e37 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -183,10 +183,20 @@ protected:
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const;
+ /// Copies the given unranked memory descriptor to heap-allocated memory (if
+ /// toDynamic is true) or to stack-allocated memory (otherwise) and returns
+ /// the new descriptor. Also frees the previously used memory (that is assumed
+ /// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
+ /// on failure.
+ Value copyUnrankedDescriptor(OpBuilder &builder, Location loc,
+ UnrankedMemRefType memRefType, Value operand,
+ bool toDynamic) const;
+
/// Copies the memory descriptor for any operands that were unranked
/// descriptors originally to heap-allocated memory (if toDynamic is true) or
- /// to stack-allocated memory (otherwise). Also frees the previously used
- /// memory (that is assumed to be heap-allocated) if toDynamic is false.
+ /// to stack-allocated memory (otherwise). The vector of descriptors is
+ /// updated in place. Also frees the previously used memory (that is assumed
+ /// to be heap-allocated) if toDynamic is false.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
TypeRange origTypes,
SmallVectorImpl<Value> &operands,
@@ -233,9 +243,7 @@ public:
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ return dispatchTo1To1(*this, op, adaptor, rewriter);
}
private:
@@ -276,7 +284,7 @@ public:
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
- return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ return dispatchTo1To1(*this, op, operands, rewriter);
}
private:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e49..2096bcb 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -74,8 +74,14 @@ public:
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
/// each of the types converted with `convertCallingConventionType`.
- Type packFunctionResults(TypeRange types,
- bool useBarePointerCallConv = false) const;
+ ///
+ /// Populate the converted (unpacked) types into `groupedTypes`, if provided.
+ /// `groupedType` contains one nested vector per input type. In case of a 1:N
+ /// conversion, a nested vector may contain 0 or more then 1 converted type.
+ Type
+ packFunctionResults(TypeRange types, bool useBarePointerCallConv = false,
+ SmallVector<SmallVector<Type>> *groupedTypes = nullptr,
+ int64_t *numConvertedTypes = nullptr) const;
/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is
@@ -88,15 +94,9 @@ public:
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
- Type convertCallingConventionType(Type type,
- bool useBarePointerCallConv = false) const;
-
- /// Promote the bare pointers in 'values' that resulted from memrefs to
- /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
- /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
- void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
- Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const;
+ LogicalResult
+ convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
+ bool useBarePointerCallConv = false) const;
/// Returns the MLIR context.
MLIRContext &getContext() const;
@@ -109,9 +109,14 @@ public:
/// Promote the LLVM representation of all operands including promoting MemRef
/// descriptors to stack and use pointers to struct to avoid the complexity
/// of the platform-specific C/C++ ABI lowering related to struct argument
- /// passing.
+ /// passing. (The ArrayRef variant is for 1:N.)
+ SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
+ ArrayRef<ValueRange> adaptorOperands,
+ OpBuilder &builder,
+ bool useBarePtrCallConv = false) const;
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
+ ValueRange adaptorOperands,
+ OpBuilder &builder,
bool useBarePtrCallConv = false) const;
/// Promote the LLVM struct representation of one MemRef descriptor to stack
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index b595b6a3..5abfb3d 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -10,8 +10,11 @@
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
constexpr const char *mallocFunctionName = "malloc";
+constexpr const char *memcpyFunctionName = "memcpy";
constexpr const char *cppStandardLibraryHeader = "cstdlib";
constexpr const char *cStandardLibraryHeader = "stdlib.h";
+constexpr const char *cppStringLibraryHeader = "cstring";
+constexpr const char *cStringLibraryHeader = "string.h";
namespace mlir {
class DialectRegistry;
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3dc48b2..da061b2 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -75,12 +75,14 @@
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6e1baaf..5180b56 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
"Test conversion patterns of only the specified dialects">,
Option<"useDynamic", "dynamic", "bool", "false",
"Use op conversion attributes to configure the conversion">,
+ Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true",
+ "Experimental performance flag to disallow pattern rollback">
];
}
@@ -968,7 +970,10 @@ def ConvertNVGPUToNVVMPass : Pass<"convert-nvgpu-to-nvvm"> {
}];
let dependentDialects = [
- "NVVM::NVVMDialect",
+ "arith::ArithDialect",
+ "LLVM::LLVMDialect",
+ "memref::MemRefDialect",
+ "NVVM::NVVMDialect"
];
}
@@ -1532,6 +1537,19 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
}
//===----------------------------------------------------------------------===//
+// VectorToAMX
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
+ let summary = "Lower the operations from the vector dialect into the AMX "
+ "dialect";
+ let dependentDialects = [
+ "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+ "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// XeVMToLLVM
//===----------------------------------------------------------------------===//
@@ -1540,4 +1558,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
let dependentDialects = ["LLVM::LLVMDialect"];
}
+//===----------------------------------------------------------------------===//
+// XeGPUToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
+ let summary = "Convert XeGPU to XeVM dialect";
+ let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect",
+ "memref::MemRefDialect", "arith::ArithDialect",
+ "LLVM::LLVMDialect", "index::IndexDialect",
+ "gpu::GPUDialect", "scf::SCFDialect"];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h
new file mode 100644
index 0000000..0ff92bc
--- /dev/null
+++ b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h
@@ -0,0 +1,27 @@
+//===- PtrToLLVM.h - Ptr to LLVM dialect conversion -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H
+#define MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+namespace ptr {
+/// Populate the convert to LLVM patterns for the `ptr` dialect.
+void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+/// Register the convert to LLVM interface for the `ptr` dialect.
+void registerConvertPtrToLLVMInterface(DialectRegistry &registry);
+} // namespace ptr
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
new file mode 100644
index 0000000..b075ac9
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -0,0 +1,26 @@
+//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+#define MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the vector to AMX ops.
+void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
new file mode 100644
index 0000000..ddaaae8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -0,0 +1,27 @@
+//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect ---------_--*- C++-*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeGPUToXeVMConversionPatterns(
+ const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 92aacda..2ccf350 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -656,6 +656,48 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
}];
}
+def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["result", "src"]>]> {
+ let summary = "AMDGPU permlane swap op";
+ let description = [{
+ High-level wrapper on `rocdl.permlane{16,32}.swap` variants for permutations
+ on rows of lanes in a subgroup.
+
+ Supports arbitrary int/float/vector types, which will be repacked to i32 and
+ one or more `rocdl.permlane_swap` ops during lowering.
+ Supported lane permutations:
+ - Swap the data between odd and even rows of 16 lanes
+ - Swap the data between the first 32 lanes and the last 32 lanes
+
+ Example:
+ ```mlir
+ %0 = amdgpu.permlane_swap %src 16 : f16
+ %1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
+ ```
+
+ Operands:
+ * `$src`: Vector register to permute across lanes of the subgroup.
+ * `$row_length`: The length of a row to permute in number of lanes (valid values are 16 and 32).
+ * `$fetch_inactive`: Optional. Used to dertermine behavior of a fetch from a disabled lane.
+ `fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value.
+ `fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`).
+ * `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from
+ a disabled lane: use the value zero, or disable the write.
+ `bound_ctrl = false`: Do not write when source is from a disabled lane
+ `bound_ctrl = true`: Use zero as input if source is from a disabled lane
+
+ Note: Lowering is only supported on gfx950 and up.
+ }];
+ let arguments = (ins AnyIntegerOrFloatOr1DVector:$src,
+ I32Attr:$row_length,
+ DefaultValuedAttr<BoolAttr, "false">:$fetch_inactive,
+ DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl);
+ let results = (outs AnyIntegerOrFloatOr1DVector:$result);
+ let assemblyFormat = [{
+ $src $row_length attr-dict `:` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
@@ -907,7 +949,8 @@ def AMDGPU_GatherToLDSOp :
The elements gathered by the subgroup will be written contiguously in order of lane ID
starting at `$dst[$dstIndices]`. Byte-sized (ex. i8) or short-sized (ex. i16)
types will be zero-padded/extended to 32 bits before being written. 96-bit types
- (ex. vector<3xf32>) will be zero-padded to 128 bits before being written.
+ (ex. vector<3xf32>) will be zero-padded to 128 bits before being written. Only the
+ offsets held by lane 0 are used.
* `$transferType`: type of the data to be transferred by each thread. This is used to determine
the size of the data to be transferred and the number of threads in the subgroup.
The transfer type must be a scalar type or a vector type with a single element type.
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt
index ed074c2..cab3469 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -4,9 +4,9 @@ add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRAMDGPUEnumsGen)
+add_mlir_dialect_tablegen_target(MLIRAMDGPUEnumsGen)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu)
mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu)
-add_public_tablegen_target(MLIRAMDGPUAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRAMDGPUAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 8880989..1371923d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name AMDGPU)
-add_public_tablegen_target(MLIRAMDGPUTransformsIncGen)
-add_dependencies(mlir-headers MLIRAMDGPUTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRAMDGPUTransformsIncGen)
add_mlir_doc(Passes AMDGPUPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 6bbde43..1236fed 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -142,14 +142,17 @@ class AMX_Op<string mnemonic, list<Trait> traits = []> :
// Tile reset.
//
-def TileZeroOp : AMX_Op<"tile_zero", [Pure,
- AMXIntrinsicOpInterface
+def TileZeroOp : AMX_Op<"tile_zero", [
+ AMXIntrinsicOpInterface,
+ MemoryEffects<[MemWrite]>
]> {
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
vector type of the result. This is eventually lowered into the
"tilezero" instruction with the corresponding tile configuration.
+ With memory-effects, each "tilezero" operation serves as a compilation
+ hint to use a separate tile register.
Example:
@@ -179,15 +182,17 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure,
// Tile memory operations.
//
-def TileLoadOp : AMX_Op<"tile_load", [Pure,
- AMXIntrinsicOpInterface
+def TileLoadOp : AMX_Op<"tile_load", [
+ AMXIntrinsicOpInterface,
+ MemoryEffects<[MemWrite]>
]> {
let summary = "tile load operation";
let description = [{
Loads a tile from memory defined by a base and indices, with the
shape defined by the 2-dim vector type of the result. This is
eventually lowered into the "tileloadd" instruction with the
- corresponding tile configuration.
+ corresponding tile configuration. With memory-effects, each "tileload"
+ operation serves as a compilation hint to use a separate tile register.
Example:
diff --git a/mlir/include/mlir/Dialect/Affine/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/CMakeLists.txt
index fe1b372..b4b904ee 100644
--- a/mlir/include/mlir/Dialect/Affine/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Affine/CMakeLists.txt
@@ -3,6 +3,6 @@ add_subdirectory(TransformOps)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Affine)
-add_public_tablegen_target(MLIRAffinePassIncGen)
+add_mlir_dialect_tablegen_target(MLIRAffinePassIncGen)
add_mlir_doc(Passes AffinePasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/TransformOps/CMakeLists.txt
index c743f5c..3d16eaa 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS AffineTransformOps.td)
mlir_tablegen(AffineTransformOps.h.inc -gen-op-decls)
mlir_tablegen(AffineTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRAffineTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRAffineTransformOpsIncGen)
add_mlir_doc(AffineTransformOps AffineLoopTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade..a3e9fc3 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -20,10 +20,10 @@ def Arith_Dialect : Dialect {
mathematical operations. This includes unary, binary, and ternary arithmetic
ops, bitwise and shift ops, cast ops, and compare ops. Operations in this
dialect also accept vectors and tensors of integers or floats. The dialect
- assumes integers are represented by bitvectors with a two's complement
- representation. Unless otherwise stated, the operations within this dialect
- propagate poison values, i.e., if any of its inputs are poison, then the
- output is poison. Unless otherwise stated, operations applied to `vector`
+ assumes integers are represented by bitvectors with a two's complement
+ representation. Unless otherwise stated, the operations within this dialect
+ propagate poison values, i.e., if any of its inputs are poison, then the
+ output is poison. Unless otherwise stated, operations applied to `vector`
and `tensor` values propagates poison elementwise.
}];
@@ -76,27 +76,29 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
-def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
-def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
-def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
-def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
+def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 2>;
+def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 3>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 4>;
+def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 5>;
+def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 6>;
+def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 7>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 8>;
+def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 9>;
+def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 10>;
+def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 11>;
+def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 12>;
+def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 13>;
+def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 14>;
+def ATOMIC_RMW_KIND_XORI : I64EnumAttrCase<"xori", 15>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
- [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
- ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
- ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ANDI,
+ ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXNUMF,
+ ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF,
+ ATOMIC_RMW_KIND_MINNUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
- ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
+ ATOMIC_RMW_KIND_XORI]> {
let cppNamespace = "::mlir::arith";
}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ef9ccb7..20c9097 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -216,14 +216,14 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
- Performs N-bit addition on the operands. The operands are interpreted as
- unsigned bitvectors. The result is represented by a bitvector containing the
- mathematical value of the addition modulo 2^n, where `n` is the bitwidth.
- Because `arith` integers use a two's complement representation, this operation
+ Performs N-bit addition on the operands. The operands are interpreted as
+ unsigned bitvectors. The result is represented by a bitvector containing the
+ mathematical value of the addition modulo 2^n, where `n` is the bitwidth.
+ Because `arith` integers use a two's complement representation, this operation
is applicable on both signed and unsigned integer operands.
The `addi` operation takes two operands and returns one result, each of
- these is required to be the same type. This type may be an integer scalar type,
+ these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.
This op supports `nuw`/`nsw` overflow flags which stands for
@@ -489,8 +489,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> {
the most significant, i.e. for `i16` given two's complement representation,
`6 / -2 = 6 / (2^16 - 2) = 0`.
- Division by zero is undefined behavior. When applied to `vector` and
- `tensor` values, the behavior is undefined if _any_ elements are divided by
+ Division by zero is undefined behavior. When applied to `vector` and
+ `tensor` values, the behavior is undefined if _any_ elements are divided by
zero.
Example:
@@ -525,9 +525,9 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> {
Signed integer division. Rounds towards zero. Treats the leading bit as
sign, i.e. `6 / -2 = -3`.
- Divison by zero, or signed division overflow (minimum value divided by -1)
- is undefined behavior. When applied to `vector` and `tensor` values, the
- behavior is undefined if _any_ of its elements are divided by zero or has a
+ Divison by zero, or signed division overflow (minimum value divided by -1)
+ is undefined behavior. When applied to `vector` and `tensor` values, the
+ behavior is undefined if _any_ of its elements are divided by zero or has a
signed division overflow.
Example:
@@ -562,10 +562,10 @@ def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui",
let description = [{
Unsigned integer division. Rounds towards positive infinity. Treats the
leading bit as the most significant, i.e. for `i16` given two's complement
- representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
+ representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
- Division by zero is undefined behavior. When applied to `vector` and
- `tensor` values, the behavior is undefined if _any_ elements are divided by
+ Division by zero is undefined behavior. When applied to `vector` and
+ `tensor` values, the behavior is undefined if _any_ elements are divided by
zero.
Example:
@@ -594,9 +594,9 @@ def Arith_CeilDivSIOp : Arith_IntBinaryOp<"ceildivsi",
let description = [{
Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`.
- Divison by zero, or signed division overflow (minimum value divided by -1)
- is undefined behavior. When applied to `vector` and `tensor` values, the
- behavior is undefined if _any_ of its elements are divided by zero or has a
+ Divison by zero, or signed division overflow (minimum value divided by -1)
+ is undefined behavior. When applied to `vector` and `tensor` values, the
+ behavior is undefined if _any_ of its elements are divided by zero or has a
signed division overflow.
Example:
@@ -624,9 +624,9 @@ def Arith_FloorDivSIOp : Arith_TotalIntBinaryOp<"floordivsi"> {
let description = [{
Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`.
- Divison by zero, or signed division overflow (minimum value divided by -1)
- is undefined behavior. When applied to `vector` and `tensor` values, the
- behavior is undefined if _any_ of its elements are divided by zero or has a
+ Divison by zero, or signed division overflow (minimum value divided by -1)
+ is undefined behavior. When applied to `vector` and `tensor` values, the
+ behavior is undefined if _any_ of its elements are divided by zero or has a
signed division overflow.
Example:
@@ -650,8 +650,8 @@ def Arith_RemUIOp : Arith_TotalIntBinaryOp<"remui"> {
Unsigned integer division remainder. Treats the leading bit as the most
significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`.
- Division by zero is undefined behavior. When applied to `vector` and
- `tensor` values, the behavior is undefined if _any_ elements are divided by
+ Division by zero is undefined behavior. When applied to `vector` and
+ `tensor` values, the behavior is undefined if _any_ elements are divided by
zero.
Example:
@@ -680,8 +680,8 @@ def Arith_RemSIOp : Arith_TotalIntBinaryOp<"remsi"> {
Signed integer division remainder. Treats the leading bit as sign, i.e. `6 %
-2 = 0`.
- Division by zero is undefined behavior. When applied to `vector` and
- `tensor` values, the behavior is undefined if _any_ elements are divided by
+ Division by zero is undefined behavior. When applied to `vector` and
+ `tensor` values, the behavior is undefined if _any_ elements are divided by
zero.
Example:
@@ -794,9 +794,9 @@ def Arith_XOrIOp : Arith_TotalIntBinaryOp<"xori", [Commutative]> {
def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
let summary = "integer left-shift";
let description = [{
- The `shli` operation shifts the integer value of the first operand to the left
- by the integer value of the second operand. The second operand is interpreted as
- unsigned. The low order bits are filled with zeros. If the value of the second
+ The `shli` operation shifts the integer value of the first operand to the left
+ by the integer value of the second operand. The second operand is interpreted as
+ unsigned. The low order bits are filled with zeros. If the value of the second
operand is greater or equal than the bitwidth of the first operand, then the
operation returns poison.
@@ -811,7 +811,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
%1 = arith.constant 5 : i8 // %1 is 0b00000101
%2 = arith.constant 3 : i8
%3 = arith.shli %1, %2 : i8 // %3 is 0b00101000
- %4 = arith.shli %1, %2 overflow<nsw, nuw> : i8
+ %4 = arith.shli %1, %2 overflow<nsw, nuw> : i8
```
}];
let hasFolder = 1;
@@ -824,9 +824,9 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
let summary = "unsigned integer right-shift";
let description = [{
- The `shrui` operation shifts an integer value of the first operand to the right
+ The `shrui` operation shifts an integer value of the first operand to the right
by the value of the second operand. The first operand is interpreted as unsigned,
- and the second operand is interpreted as unsigned. The high order bits are always
+ and the second operand is interpreted as unsigned. The high order bits are always
filled with zeros. If the value of the second operand is greater or equal than the
bitwidth of the first operand, then the operation returns poison.
@@ -848,11 +848,11 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
let summary = "signed integer right-shift";
let description = [{
- The `shrsi` operation shifts an integer value of the first operand to the right
- by the value of the second operand. The first operand is interpreted as signed,
- and the second operand is interpreter as unsigned. The high order bits in the
- output are filled with copies of the most-significant bit of the shifted value
- (which means that the sign of the value is preserved). If the value of the second
+ The `shrsi` operation shifts an integer value of the first operand to the right
+ by the value of the second operand. The first operand is interpreted as signed,
+ and the second operand is interpreter as unsigned. The high order bits in the
+ output are filled with copies of the most-significant bit of the shifted value
+ (which means that the sign of the value is preserved). If the value of the second
operand is greater or equal than bitwidth of the first operand, then the operation
returns poison.
@@ -1229,28 +1229,28 @@ def Arith_ScalingExtFOp
let summary = "Upcasts input floats using provided scales values following "
"OCP MXFP Spec";
let description = [{
- This operation upcasts input floating-point values using provided scale
- values. It expects both scales and the input operand to be of the same shape,
- making the operation elementwise. Scales are usually calculated per block
+ This operation upcasts input floating-point values using provided scale
+ values. It expects both scales and the input operand to be of the same shape,
+ making the operation elementwise. Scales are usually calculated per block
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
- If scales are calculated per block where blockSize != 1, then scales may
- require broadcasting to make this operation elementwise. For example, let's
- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
- assuming quantization happens on the last axis, the input can be reshaped to
- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
- per block on the last axis. Therefore, scales will be of shape
- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
- shape as long as it is broadcast compatible with the input, e.g.,
+ If scales are calculated per block where blockSize != 1, then scales may
+ require broadcasting to make this operation elementwise. For example, let's
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
`<1 x 1 x ... (dimN/blockSize) x 1>`.
- In this example, before calling into `arith.scaling_extf`, scales must be
- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
- that there could be multiple quantization axes. Internally,
+ In this example, before calling into `arith.scaling_extf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
`arith.scaling_extf` would perform the following:
-
+
```
- resultTy = get_type(result)
+ resultTy = get_type(result)
scaleTy = get_type(scale)
inputTy = get_type(input)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
@@ -1258,12 +1258,12 @@ def Arith_ScalingExtFOp
input.extf = arith.extf(input) : inputTy to resultTy
result = arith.mulf(scale.extf, input.extf)
```
- It propagates NaN values. Therefore, if either scale or the input element
+ It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
}];
let hasVerifier = 1;
let assemblyFormat =
- [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
+ [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
type($in) `,` type($scale) `to` type($out)}];
}
@@ -1373,28 +1373,28 @@ def Arith_ScalingTruncFOp
let summary = "Downcasts input floating point values using provided scales "
"values following OCP MXFP Spec";
let description = [{
- This operation downcasts input using the provided scale values. It expects
- both scales and the input operand to be of the same shape and, therefore,
- makes the operation elementwise. Scales are usually calculated per block
+ This operation downcasts input using the provided scale values. It expects
+ both scales and the input operand to be of the same shape and, therefore,
+ makes the operation elementwise. Scales are usually calculated per block
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
Users are required to normalize and clamp the scales as necessary before calling
passing them to this operation. OCP MXFP spec also does the flushing of denorms
- on the input operand, which should be handled during lowering by passing appropriate
- fastMath flag to this operation.
-
- If scales are calculated per block where blockSize != 1, scales may require
- broadcasting to make this operation elementwise. For example, let's say the
- input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
- assuming quantization happens on the last axis, the input can be reshaped to
- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
- per block on the last axis. Therefore, scales will be of shape
- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
- shape as long as it is broadcast compatible with the input, e.g.,
+ on the input operand, which should be handled during lowering by passing appropriate
+ fastMath flag to this operation.
+
+ If scales are calculated per block where blockSize != 1, scales may require
+ broadcasting to make this operation elementwise. For example, let's say the
+ input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
`<1 x 1 x ... (dimN/blockSize) x 1>`.
- In this example, before calling into `arith.scaling_truncf`, scales must be
- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
- that there could be multiple quantization axes. Internally,
+ In this example, before calling into `arith.scaling_truncf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:
```
@@ -1409,7 +1409,7 @@ def Arith_ScalingTruncFOp
}];
let hasVerifier = 1;
let assemblyFormat =
- [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
+ [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
type($in) `,` type($scale) `to` type($out)}];
}
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/Transforms/CMakeLists.txt
index 3f39e40..7ae7165 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Arith)
-add_public_tablegen_target(MLIRArithTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArithTransformsIncGen)
add_mlir_doc(Passes ArithPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
index 3de3ec3..4d58c2e 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -3,6 +3,6 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmNeonConversionsIncGen)
add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
index b8bc72a..f652ba6 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td)
mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)
add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 9801d8b..2e4538f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -8,25 +8,25 @@ mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
-add_public_tablegen_target(MLIRArmSMEOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSMEOpsIncGen)
# Generate LLVM IR Conversions
set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
mlir_tablegen(ArmSMEOpsConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSMEConversionsIncGen)
# Generate op interface declarations and definitions
set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
mlir_tablegen(ArmSMEOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(ArmSMEOpInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRArmSMEOpInterfaces)
+add_mlir_dialect_tablegen_target(MLIRArmSMEOpInterfaces)
# Generate declarations and definitions of ArmSME intrinsic Ops
set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
mlir_tablegen(ArmSMEIntrinsicOps.cpp.inc -gen-op-defs)
mlir_tablegen(ArmSMEIntrinsicConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRArmSMEIntrinsicOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSMEIntrinsicOpsIncGen)
# Generate the docs
add_mlir_doc(ArmSMEOps ArmSMEOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
index 509f3fc..a96d1d9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -2,7 +2,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSME)
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRArmSMETransformsIncGen)
-add_dependencies(mlir-headers MLIRArmSMETransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSMETransformsIncGen)
add_mlir_doc(Passes ArmSMEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt
index 06595b7..b9e46b2 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -3,4 +3,4 @@ add_mlir_doc(ArmSVE ArmSVE Dialects/ -gen-dialect-doc -dialect=arm_sve)
set(LLVM_TARGET_DEFINITIONS ArmSVE.td)
mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRArmSVEConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSVEConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
index ce8d8fe..699edab 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 7226642..39601d6 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSVE)
-add_public_tablegen_target(MLIRArmSVEPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRArmSVEPassIncGen)
add_mlir_doc(Passes ArmSVEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Async/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/CMakeLists.txt
index cabd5d3..6abdf5f 100644
--- a/mlir/include/mlir/Dialect/Async/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Async/CMakeLists.txt
@@ -4,6 +4,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Async)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Async)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Async)
-add_public_tablegen_target(MLIRAsyncPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRAsyncPassIncGen)
add_mlir_doc(Passes AsyncPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 3ead521..4e79862 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -8,11 +8,9 @@ add_mlir_interface(BufferViewFlowOpInterface)
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
-add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRBufferizationEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
-add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt
index 95276e3..dbccd7c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS BufferizationTransformOps.td)
mlir_tablegen(BufferizationTransformOps.h.inc -gen-op-decls)
mlir_tablegen(BufferizationTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRBufferizationTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRBufferizationTransformOpsIncGen)
add_mlir_doc(BufferizationTransformOps BufferizationTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt
index dcae4b8..dcdd49c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Bufferization)
-add_public_tablegen_target(MLIRBufferizationPassIncGen)
-add_dependencies(mlir-headers MLIRBufferizationPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRBufferizationPassIncGen)
add_mlir_doc(Passes BufferizationPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index b5a1242..1137651 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -15,10 +15,16 @@
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
#define MLIR_DIALECT_COMMONFOLDERS_H
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Types.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+
+#include <cassert>
+#include <cstddef>
#include <optional>
namespace mlir {
@@ -30,11 +36,13 @@ class PoisonAttr;
/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT,
+template <class AttrElementT, //
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT = function_ref<
- std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
+ std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
CalculationT &&calculate) {
@@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!calRes)
return {};
- return AttrElementT::get(resultType, *calRes);
+ return ResultAttrElementT::get(resultType, *calRes);
}
if (isa<SplatElementsAttr>(operands[0]) &&
@@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
return {};
auto lhsIt = *maybeLhsIt;
auto rhsIt = *maybeRhsIt;
- SmallVector<ElementValueT, 4> elementResults;
+ SmallVector<ResultElementValueT, 4> elementResults;
elementResults.reserve(lhs.getNumElements());
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) {
auto elementResult = calculate(*lhsIt, *rhsIt);
@@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT,
+template <class AttrElementT, //
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT = function_ref<
- std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
+ std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
@@ -139,64 +149,73 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
return operands[1];
}
- auto getResultType = [](Attribute attr) -> Type {
+ auto getAttrType = [](Attribute attr) -> Type {
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
return typed.getType();
return {};
};
- Type lhsType = getResultType(operands[0]);
- Type rhsType = getResultType(operands[1]);
+ Type lhsType = getAttrType(operands[0]);
+ Type rhsType = getAttrType(operands[1]);
if (!lhsType || !rhsType)
return {};
if (lhsType != rhsType)
return {};
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ ResultAttrElementT, ResultElementValueT,
CalculationT>(
operands, lhsType, std::forward<CalculationT>(calculate));
}
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
- class PoisonAttr = void,
+ class PoisonAttr = void, //
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ ResultAttrElementT>(
operands, resultType,
- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
- return calculate(a, b);
- });
+ [&](ElementValueT a, ElementValueT b)
+ -> std::optional<ResultElementValueT> { return calculate(a, b); });
}
-template <class AttrElementT,
+template <class AttrElementT, //
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ ResultAttrElementT>(
operands,
- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
- return calculate(a, b);
- });
+ [&](ElementValueT a, ElementValueT b)
+ -> std::optional<ResultElementValueT> { return calculate(a, b); });
}
/// Performs constant folding `calculate` with element-wise behavior on the one
/// attributes in `operands` and returns the result if possible.
+/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT,
+template <class AttrElementT, //
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<std::optional<ElementValueT>(ElementValueT)>>
+ function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
+ Type resultType,
CalculationT &&calculate) {
- if (!llvm::getSingleElement(operands))
+ if (!resultType || !llvm::getSingleElement(operands))
return {};
static_assert(
@@ -214,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
auto res = calculate(op.getValue());
if (!res)
return {};
- return AttrElementT::get(op.getType(), *res);
+ return ResultAttrElementT::get(resultType, *res);
}
if (isa<SplatElementsAttr>(operands[0])) {
// Both operands are splats so we can avoid expanding the values out and
@@ -224,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
auto elementResult = calculate(op.getSplatValue<ElementValueT>());
if (!elementResult)
return {};
- return DenseElementsAttr::get(op.getType(), *elementResult);
+ return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
} else if (isa<ElementsAttr>(operands[0])) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
@@ -234,7 +253,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
if (!maybeOpIt)
return {};
auto opIt = *maybeOpIt;
- SmallVector<ElementValueT> elementResults;
+ SmallVector<ResultElementValueT> elementResults;
elementResults.reserve(op.getNumElements());
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
auto elementResult = calculate(*opIt);
@@ -242,19 +261,81 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
return {};
elementResults.push_back(*elementResult);
}
- return DenseElementsAttr::get(op.getShapedType(), elementResults);
+ return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
}
return {};
}
-template <class AttrElementT,
+/// Performs constant folding `calculate` with element-wise behavior on the one
+/// attributes in `operands` and returns the result if possible.
+/// Uses the operand element type for the element type of the returned
+/// attribute.
+/// Optional PoisonAttr template argument allows to specify 'poison' attribute
+/// which will be directly propagated to result.
+template <class AttrElementT, //
+ class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
+ class CalculationT =
+ function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
+Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
+ CalculationT &&calculate) {
+ if (!llvm::getSingleElement(operands))
+ return {};
+
+ static_assert(
+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
+ "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
+ "void as template argument to opt-out from poison semantics.");
+ if constexpr (!std::is_void_v<PoisonAttr>) {
+ if (isa<PoisonAttr>(operands[0]))
+ return operands[0];
+ }
+
+ auto getAttrType = [](Attribute attr) -> Type {
+ if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
+ return typed.getType();
+ return {};
+ };
+
+ Type operandType = getAttrType(operands[0]);
+ if (!operandType)
+ return {};
+
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ ResultAttrElementT, ResultElementValueT,
+ CalculationT>(
+ operands, operandType, std::forward<CalculationT>(calculate));
+}
+
+template <class AttrElementT, //
+ class ElementValueT = typename AttrElementT::ValueType,
+ class PoisonAttr = ub::PoisonAttr,
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
+ class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
+Attribute constFoldUnaryOp(ArrayRef<Attribute> operands, Type resultType,
+ CalculationT &&calculate) {
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ ResultAttrElementT>(
+ operands, resultType,
+ [&](ElementValueT a) -> std::optional<ResultElementValueT> {
+ return calculate(a);
+ });
+}
+
+template <class AttrElementT, //
class ElementValueT = typename AttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class CalculationT = function_ref<ElementValueT(ElementValueT)>>
+ class ResultAttrElementT = AttrElementT,
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
+ class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
- operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
+ ResultAttrElementT>(
+ operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
return calculate(a);
});
}
diff --git a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
index 837664e..c068941 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
@@ -6,4 +6,4 @@ mlir_tablegen(ComplexEnums.h.inc -gen-enum-decls)
mlir_tablegen(ComplexEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRComplexAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRComplexAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt b/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
index 4f8382e..a7695e6 100644
--- a/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
@@ -6,5 +6,4 @@ add_mlir_doc(DLTIAttrs DLTIDialect Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS DLTIAttrs.td)
mlir_tablegen(DLTIAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=dlti)
mlir_tablegen(DLTIAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=dlti)
-add_public_tablegen_target(MLIRDLTIAttrsIncGen)
-add_dependencies(mlir-headers MLIRDLTIAttrsIncGen)
+add_mlir_dialect_tablegen_target(MLIRDLTIAttrsIncGen)
diff --git a/mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
index 1188d1a..9c13682 100644
--- a/mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/DLTI/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS DLTITransformOps.td)
mlir_tablegen(DLTITransformOps.h.inc -gen-op-decls)
mlir_tablegen(DLTITransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRDLTITransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRDLTITransformOpsIncGen)
add_mlir_doc(DLTITransformOps DLTITransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
index 299cee7..db106b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt
@@ -4,12 +4,11 @@ add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc -dialect emitc)
set(LLVM_TARGET_DEFINITIONS EmitCInterfaces.td)
mlir_tablegen(EmitCInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(EmitCInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIREmitCInterfacesIncGen)
-add_dependencies(mlir-generic-headers MLIREmitCInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIREmitCInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIREmitCAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIREmitCAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 1984ed8..eb7ddeb 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -53,6 +53,9 @@ bool isPointerWideType(mlir::Type type);
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;
+/// Determines whether \p type is a valid fundamental C++ type in EmitC.
+bool isFundamentalType(mlir::Type type);
+
} // namespace emitc
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 937b34a6..fb7a108 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -43,7 +43,8 @@ class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
let extraClassDeclaration = [{
bool hasSideEffects() {
- return false;
+ // If operand is fundamental type, the operation is pure.
+ return !isFundamentalType(getOperand().getType());
}
}];
}
@@ -57,7 +58,9 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
let extraClassDeclaration = [{
bool hasSideEffects() {
- return false;
+ // If both operands are fundamental types, the operation is pure.
+ return !isFundamentalType(getOperand(0).getType()) ||
+ !isFundamentalType(getOperand(1).getType());
}
}];
}
@@ -452,9 +455,11 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}
-def EmitC_ExpressionOp : EmitC_Op<"expression",
- [HasOnlyGraphRegion, OpAsmOpInterface,
- SingleBlockImplicitTerminator<"emitc::YieldOp">, NoRegionArguments]> {
+def EmitC_ExpressionOp
+ : EmitC_Op<
+ "expression", [HasOnlyGraphRegion, OpAsmOpInterface,
+ IsolatedFromAbove,
+ SingleBlockImplicitTerminator<"emitc::YieldOp">]> {
let summary = "Expression operation";
let description = [{
The `emitc.expression` operation returns a single SSA value which is yielded by
@@ -491,12 +496,13 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
at its use.
}];
- let arguments = (ins UnitAttr:$do_not_inline);
+ let arguments = (ins Variadic<AnyTypeOf<[EmitCType, EmitC_LValueType]>>:$defs,
+ UnitAttr:$do_not_inline);
let results = (outs EmitCType:$result);
let regions = (region SizedRegion<1>:$region);
let hasVerifier = 1;
- let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
bool hasSideEffects() {
@@ -507,6 +513,13 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
return llvm::any_of(getRegion().front().without_terminator(), predicate);
};
Operation *getRootOp();
+ Block &createBody() {
+ assert(getRegion().empty() && "expression already has a body");
+ Block &block = getRegion().emplaceBlock();
+ for (auto operand : getOperands())
+ block.addArgument(operand.getType(), operand.getLoc());
+ return block;
+ }
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
@@ -1051,6 +1064,8 @@ def EmitC_MemberOp : EmitC_Op<"member"> {
```mlir
%0 = "emitc.member" (%arg0) {member = "a"}
: (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.lvalue<i32>
+ %1 = "emitc.member" (%arg0) {member = "b"}
+ : (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.array<2xi32>
```
}];
@@ -1058,7 +1073,7 @@ def EmitC_MemberOp : EmitC_Op<"member"> {
Arg<StrAttr, "the member to access">:$member,
EmitC_LValueOf<[EmitC_OpaqueType]>:$operand
);
- let results = (outs EmitC_LValueOf<[EmitCType]>);
+ let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>);
}
def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
@@ -1073,6 +1088,9 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
%0 = "emitc.member_of_ptr" (%arg0) {member = "a"}
: (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>)
-> !emitc.lvalue<i32>
+ %1 = "emitc.member_of_ptr" (%arg0) {member = "b"}
+ : (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>)
+ -> !emitc.array<2xi32>
```
}];
@@ -1080,7 +1098,7 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> {
Arg<StrAttr, "the member to access">:$member,
EmitC_LValueOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand
);
- let results = (outs EmitC_LValueOf<[EmitCType]>);
+ let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>);
}
def EmitC_ConditionalOp : EmitC_Op<"conditional",
@@ -1697,6 +1715,7 @@ def EmitC_GetFieldOp
let arguments = (ins FlatSymbolRefAttr:$field_name);
let results = (outs EmitCType:$result);
let assemblyFormat = "$field_name `:` type($result) attr-dict";
+ let hasVerifier = 1;
}
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
index 0b507d7..bf3b4b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC)
-add_public_tablegen_target(MLIREmitCTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIREmitCTransformsIncGen)
add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Func/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Func/IR/CMakeLists.txt
index 08a6123..b592978 100644
--- a/mlir/include/mlir/Dialect/Func/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Func/IR/CMakeLists.txt
@@ -3,6 +3,6 @@ mlir_tablegen(FuncOps.h.inc -gen-op-decls)
mlir_tablegen(FuncOps.cpp.inc -gen-op-defs)
mlir_tablegen(FuncOpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(FuncOpsDialect.cpp.inc -gen-dialect-defs)
-add_public_tablegen_target(MLIRFuncOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRFuncOpsIncGen)
add_mlir_doc(FuncOps FuncOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt
index 7ac6504..b505489 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS FuncTransformOps.td)
mlir_tablegen(FuncTransformOps.h.inc -gen-op-decls)
mlir_tablegen(FuncTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRFuncTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRFuncTransformOpsIncGen)
add_mlir_doc(FuncTransformOps FuncTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Func/Transforms/CMakeLists.txt
index 33c72e3..4ca0228 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Func)
-add_public_tablegen_target(MLIRFuncTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRFuncTransformsIncGen)
add_mlir_doc(Passes FuncPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/IR/CMakeLists.txt
index 3c95b4f..f7c84fc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/IR/CMakeLists.txt
@@ -4,29 +4,29 @@ add_mlir_doc(GPUOps GPUOps Dialects/ -gen-op-doc)
set(LLVM_TARGET_DEFINITIONS GPUBase.td)
mlir_tablegen(GPUOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(GPUOpInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRGPUOpInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRGPUOpInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS ParallelLoopMapperAttr.td)
mlir_tablegen(ParallelLoopMapperEnums.h.inc -gen-enum-decls)
mlir_tablegen(ParallelLoopMapperEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRParallelLoopMapperEnumsGen)
+add_mlir_dialect_tablegen_target(MLIRParallelLoopMapperEnumsGen)
set(LLVM_TARGET_DEFINITIONS GPUOps.td)
mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRGPUOpsEnumsGen)
+add_mlir_dialect_tablegen_target(MLIRGPUOpsEnumsGen)
set(LLVM_TARGET_DEFINITIONS CompilationAttrInterfaces.td)
mlir_tablegen(CompilationAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(CompilationAttrInterfaces.cpp.inc -gen-attr-interface-defs)
-add_public_tablegen_target(MLIRGPUCompilationAttrInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRGPUCompilationAttrInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS GPUDeviceMappingAttr.td)
mlir_tablegen(GPUDeviceMapperEnums.h.inc -gen-enum-decls)
mlir_tablegen(GPUDeviceMapperEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRGPUDeviceMapperEnumsGen)
+add_mlir_dialect_tablegen_target(MLIRGPUDeviceMapperEnumsGen)
set(LLVM_TARGET_DEFINITIONS GPUOps.td)
mlir_tablegen(GPUOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=gpu)
mlir_tablegen(GPUOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=gpu)
-add_public_tablegen_target(MLIRGPUOpsAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRGPUOpsAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 2ed7d38..3fb0cfe 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -804,8 +804,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
Optional<Index>:$clusterSizeY,
Optional<Index>:$clusterSizeZ,
Optional<I32>:$dynamicSharedMemorySize,
- OptionalAttr<SymbolRefAttr>:$kernelFunc,
- OptionalAttr<SymbolRefAttr>:$kernelModule)>,
+ OptionalAttr<FlatSymbolRefAttr>:$module,
+ OptionalAttr<FlatSymbolRefAttr>:$function)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
let summary = "GPU kernel launch operation";
@@ -839,7 +839,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
- a variadic number of Workgroup memory attributions.
- a variadic number of Private memory attributions.
- The `kernelFunc` and `kernelModule` attributes are optional and specifies
+ The `function` and `module` attributes are optional and specifies
the kernel name and a module in which the kernel should be outlined.
Syntax:
@@ -850,6 +850,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
`blocks` `(` ssa-id-list `)` `in` ssa-reassignment
`threads` `(` ssa-id-list `)` `in` ssa-reassignment
(dynamic_shared_memory_size ssa-use)?
+ (`module(` symbol-ref-id `)`)?
+ (`function(` symbol-ref-id `)`)?
memory-attribution
region attr-dict?
ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -907,6 +909,14 @@ def GPU_LaunchOp : GPU_Op<"launch", [
// sizes are immediately usable inside body region.
"some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
}
+
+ // Launch with module and function attributes.
+ gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
+ threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5)
+ module(@kernel_module) function(@kernel_func) {
+ "some_op"(%bx, %tx) : (index, index) -> ()
+ %42 = load %val1[%bx] : memref<?xf32, 1>
+ }
```
Rationale: using operation/block arguments gives analyses a clear way of
@@ -931,7 +941,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [
CArg<"TypeRange", "{}">:$privateAttributions,
CArg<"Value", "nullptr">:$clusterSizeX,
CArg<"Value", "nullptr">:$clusterSizeY,
- CArg<"Value", "nullptr">:$clusterSizeZ)>
+ CArg<"Value", "nullptr">:$clusterSizeZ,
+ CArg<"FlatSymbolRefAttr", "nullptr">:$module,
+ CArg<"FlatSymbolRefAttr", "nullptr">:$function)>,
];
let extraClassDeclaration = [{
@@ -1505,7 +1517,7 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
/// Sets the targets of the module.
void setTargets(ArrayRef<TargetAttrInterface> targets);
}];
-
+
let hasVerifier = 1;
}
@@ -3197,7 +3209,58 @@ def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0",
bool isDefinedOutsideOfRegion(Value value) {
return !getRegion().isAncestor(value.getParentRegion());
}
+
+ /// Get the terminator of the warp region.
+ gpu::YieldOp getTerminator();
+ }];
+}
+
+def GPU_BroadcastType : I32EnumAttr<"BroadcastType",
+ "a lane to broadcast from",
+ [
+ I32EnumAttrCase<"first_active_lane", 0>,
+ I32EnumAttrCase<"any_lane", 1>,
+ I32EnumAttrCase<"specific_lane", 2>
+ ]>{
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::gpu";
+}
+def GPU_BroadcastTypeAttr : EnumAttr<GPU_Dialect, GPU_BroadcastType, "broadcast">;
+
+def GPU_SubgroupBroadcastOp : GPU_Op<"subgroup_broadcast",
+ [NoMemoryEffect, AllTypesMatch<["result", "src"]>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ DeclareOpInterfaceMethods<ConditionallySpeculatable, ["getSpeculatability"]>] #
+ ElementwiseMappable.traits>,
+ Arguments<(ins AnyType:$src,
+ Optional<I32>:$lane,
+ GPU_BroadcastTypeAttr:$broadcast_type)> {
+ let summary = "Broadcasts a value from the specific lane across subgroup";
+ let description = [{
+ Broadcasts a value from one lane to all active lanes in a subgroup. The
+ result is guaranteed to be uniform across the active lanes in subgroup.
+
+ The possible broadcast types are:
+
+ * `first_active_lane` - broadcasts the value from the first active lane
+ in the subgroup.
+ * `specific_lane` - broadcasts from the specified lane. The lane index
+ must be uniform and within the subgroup size. The result is poison if the
+ lane index is invalid, non subgroup-uniform, or if the source lane is not
+ active.
+ * `any_lane` - broadcasts the value from any lane of the subgroup,
+ assuming the input is already subgroup uniform. The result is poison if
+ the input is not uniform. This is useful to convey uniformity to the
+ compiler to enable more optimizations. Also, it allows more speculation
+ opportunities than `first_active_lane` since `first_active_lane` results
+ can depend on active lanes which may change during speculation across
+ control flow.
+ }];
+ let results = (outs AnyType:$result);
+ let assemblyFormat = [{
+ $src `,` $broadcast_type ($lane^)? attr-dict `:` type($result)
}];
+ let hasVerifier = 1;
}
#endif // GPU_OPS
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
index c99f3df..41c7499 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS GPUTransformOps.td)
mlir_tablegen(GPUTransformOps.h.inc -gen-op-decls)
mlir_tablegen(GPUTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRGPUTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRGPUTransformOpsIncGen)
add_mlir_doc(GPUTransformOps GPUTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 87423c6..3a8caf8 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -331,7 +331,10 @@ def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
Collects patterns that are tryin to promote `gpu.shuffle`s to specialized
AMDGPU intrinsics.
}];
- let assemblyFormat = "attr-dict";
+ let arguments = (ins OptionalAttr<StrAttr>:$chipset);
+ let assemblyFormat = [{
+ (`chipset` `=` $chipset^)? attr-dict
+ }];
}
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/Transforms/CMakeLists.txt
index 60daed4..7cc5eb0 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/CMakeLists.txt
@@ -2,6 +2,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name GPU)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GPU)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
-add_public_tablegen_target(MLIRGPUPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRGPUPassIncGen)
add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index b4fd55e6..d5c253d 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -114,7 +114,8 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
/// Tries to promote `gpu.shuffle`s to specialized AMDGPU intrinsics.
-void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
+void populateGpuPromoteShuffleToAMDGPUPatterns(
+ RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset);
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
index 861db0c..9dcadec 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
@@ -5,22 +5,19 @@ add_mlir_doc(IRDLOps IRDL Dialects/ -gen-dialect-doc -dialect=irdl)
set(LLVM_TARGET_DEFINITIONS IRDLInterfaces.td)
mlir_tablegen(IRDLInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(IRDLInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRIRDLInterfacesIncGen)
-add_dependencies(mlir-generic-headers MLIRIRDLInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIRIRDLInterfacesIncGen)
# Add IRDL operations
set(LLVM_TARGET_DEFINITIONS IRDLOps.td)
mlir_tablegen(IRDLOps.h.inc -gen-op-decls)
mlir_tablegen(IRDLOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRIRDLOpsIncGen)
-add_dependencies(mlir-generic-headers MLIRIRDLOpsIncGen)
+add_mlir_generic_tablegen_target(MLIRIRDLOpsIncGen)
# Add IRDL types
set(LLVM_TARGET_DEFINITIONS IRDLTypes.td)
mlir_tablegen(IRDLTypesGen.h.inc -gen-typedef-decls)
mlir_tablegen(IRDLTypesGen.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRIRDLTypesIncGen)
-add_dependencies(mlir-generic-headers MLIRIRDLTypesIncGen)
+add_mlir_generic_tablegen_target(MLIRIRDLTypesIncGen)
# Add IRDL attributes
set(LLVM_TARGET_DEFINITIONS IRDLAttributes.td)
@@ -28,5 +25,4 @@ mlir_tablegen(IRDLEnums.h.inc -gen-enum-decls)
mlir_tablegen(IRDLEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(IRDLAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(IRDLAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRIRDLAttributesIncGen)
-add_dependencies(mlir-generic-headers MLIRIRDLAttributesIncGen)
+add_mlir_generic_tablegen_target(MLIRIRDLAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 5b7df69..cb24893 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
namespace NVVM {
@@ -26,13 +27,26 @@ namespace NVVM {
enum class PTXRegisterMod {
/// Read register with no modifier
Read = 0,
- /// Read register with '+' modifier
+ /// Write register with '=' modifier
Write = 2,
- /// Read register with '=' modifier.
- /// Note that, this is not natively supported by LLVM, but it is possible to
- /// set read and write for the same operand.
+ /// ReadWrite register with '+' modifier.
+ /// Note that, this is not natively supported by LLVM, the Interface does
+ /// mapping
ReadWrite = 1,
};
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ PTXRegisterMod mod) {
+ switch (mod) {
+ case PTXRegisterMod::Read:
+ return os << "Read";
+ case PTXRegisterMod::Write:
+ return os << "Write";
+ case PTXRegisterMod::ReadWrite:
+ return os << "ReadWrite";
+ }
+ llvm_unreachable("Unknown PTXRegisterMod value");
+}
} // namespace NVVM
} // namespace mlir
@@ -54,16 +68,23 @@ class PtxBuilder {
SmallVector<Value> ptxOperands;
// Register constraints (read, write, readwrite) and register data types
std::string registerConstraints;
-
+ // Modifiers
+ SmallVector<PTXRegisterMod> registerModifiers;
+ // Has return value as write-only or read-write
bool hasResult = false;
+ // Indicates if the Op will handle the register mapping manually.
+ bool needsManualRegisterMapping = false;
public:
/// Single constructor that only initializes members.
- PtxBuilder(Operation *op, PatternRewriter &rewriter)
- : interfaceOp(op), rewriter(rewriter) {}
+ PtxBuilder(Operation *op, PatternRewriter &rewriter,
+ bool needsManualRegisterMapping = false)
+ : interfaceOp(op), rewriter(rewriter),
+ needsManualRegisterMapping(needsManualRegisterMapping) {}
/// Add an operand with the read/write input type.
- void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+ LogicalResult insertValue(Value v,
+ PTXRegisterMod itype = PTXRegisterMod::Read);
/// Builds the inline assembly Op and returns it. The `insertValue` needs to
/// be called to pass operands before building the PTX.
@@ -74,6 +95,16 @@ public:
void buildAndReplaceOp();
};
+/// Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the
+/// PTX code.
+void countPlaceholderNumbers(StringRef ptxCode,
+ llvm::SmallDenseSet<unsigned> &seenRW,
+ llvm::SmallDenseSet<unsigned> &seenW,
+ llvm::SmallDenseSet<unsigned> &seenR,
+ llvm::SmallVectorImpl<unsigned> &rwNums,
+ llvm::SmallVectorImpl<unsigned> &wNums,
+ llvm::SmallVectorImpl<unsigned> &rNums);
+
} // namespace NVVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index e98b94b..086cdcc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
following this order:
1) Adds results
2) Adds operands
- 3) Adds attributes
+ 3) Adds attributes
+ Returns true if the OP is going to do register mapping itself
}],
- /*retType=*/"void",
+ /*retType=*/"bool",
/*methodName=*/"getAsmValues",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+ "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
+ ),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;
// Step 1. Add results
- for (auto val : op->getResults())
- asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto val : op->getResults())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
// Step 2. Add operands
for (auto val : op->getOperands())
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
}
}
+ return false; // No manual mapping needed
}]
>
];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index cfad07e5..8d9474b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -11,17 +11,17 @@ mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls
-attrdefs-dialect=llvm)
mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs
-attrdefs-dialect=llvm)
-add_public_tablegen_target(MLIRLLVMOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMTypes.td)
mlir_tablegen(LLVMTypes.h.inc -gen-typedef-decls -typedefs-dialect=llvm)
mlir_tablegen(LLVMTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=llvm)
-add_public_tablegen_target(MLIRLLVMTypesIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMTypesIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td)
mlir_tablegen(LLVMIntrinsicOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMIntrinsicOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLLVMIntrinsicOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMIntrinsicOpsIncGen)
add_mlir_doc(LLVMOps LLVMOps Dialects/ -gen-op-doc)
add_mlir_doc(LLVMIntrinsicOps LLVMIntrinsicOps Dialects/ -gen-op-doc)
@@ -33,32 +33,30 @@ mlir_tablegen(LLVMAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(LLVMAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(LLVMTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(LLVMTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRLLVMInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions)
mlir_tablegen(LLVMOpFromLLVMIRConversions.inc -gen-op-from-llvmir-conversions)
-add_public_tablegen_target(MLIRLLVMConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMConversionsIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td)
mlir_tablegen(LLVMIntrinsicConversions.inc -gen-llvmir-conversions)
mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conversions)
mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
-add_public_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)
set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td)
mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
-add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
+add_mlir_dialect_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS NVVMRequiresSMTraits.td)
mlir_tablegen(NVVMRequiresSMTraits.h.inc -gen-op-interface-decls)
mlir_tablegen(NVVMRequiresSMTraits.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRNVVMRequiresSMTraitsIncGen)
-add_dependencies(mlir-headers MLIRNVVMRequiresSMTraitsIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVVMRequiresSMTraitsIncGen)
add_mlir_dialect(NVVMOps nvvm)
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
@@ -70,7 +68,7 @@ mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
-add_public_tablegen_target(MLIRNVVMConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVVMConversionsIncGen)
add_mlir_dialect(ROCDLOps rocdl)
add_mlir_doc(ROCDLOps ROCDLDialect Dialects/ -gen-dialect-doc -dialect=rocdl)
@@ -78,7 +76,7 @@ set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
mlir_tablegen(ROCDLOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=rocdl)
mlir_tablegen(ROCDLOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=rocdl)
-add_public_tablegen_target(MLIRROCDLConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRROCDLConversionsIncGen)
add_mlir_dialect(VCIXOps vcix)
add_mlir_doc(VCIXOps VCIXDialect Dialects/ -gen-dialect-doc -dialect=vcix)
@@ -86,7 +84,7 @@ set(LLVM_TARGET_DEFINITIONS VCIXOps.td)
mlir_tablegen(VCIXConversions.inc -gen-llvmir-conversions)
mlir_tablegen(VCIXOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=vcix)
mlir_tablegen(VCIXOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=vcix)
-add_public_tablegen_target(MLIRVCIXConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRVCIXConversionsIncGen)
add_mlir_dialect(XeVMOps xevm)
add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm)
@@ -96,4 +94,4 @@ mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm)
mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm)
-add_public_tablegen_target(MLIRXeVMConversionsIncGen)
+add_mlir_dialect_tablegen_target(MLIRXeVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 790d2e7..fc5c5f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -11,8 +11,10 @@
include "mlir/Dialect/LLVMIR/LLVMDialect.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
+include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
// All of the attributes will extend this class.
class LLVM_Attr<string name, string attrMnemonic,
@@ -23,6 +25,41 @@ class LLVM_Attr<string name, string attrMnemonic,
}
//===----------------------------------------------------------------------===//
+// AddressSpaceAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_AddressSpaceAttr :
+ LLVM_Attr<"AddressSpace", "address_space", [
+ DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
+ ]> {
+ let summary = "LLVM address space";
+ let description = [{
+ The `address_space` attribute represents an LLVM address space. It takes an
+ unsigned integer parameter that specifies the address space number.
+
+ Different address spaces in LLVM can have different properties:
+ - Address space 0 is the default/generic address space
+ - Other address spaces may have specific semantics (e.g., shared memory,
+ constant memory, etc.) depending on the target architecture
+
+ Example:
+
+ ```mlir
+ // Address space 0 (default)
+ #llvm.address_space<0>
+
+ // Address space 1 (e.g., global memory on some targets)
+ #llvm.address_space<1>
+
+ // Address space 3 (e.g., shared memory on some GPU targets)
+ #llvm.address_space<3>
+ ```
+ }];
+ let parameters = (ins "unsigned":$addressSpace);
+ let assemblyFormat = "`<` $addressSpace `>`";
+}
+
+//===----------------------------------------------------------------------===//
// CConvAttr
//===----------------------------------------------------------------------===//
@@ -553,7 +590,7 @@ def LLVM_DIGlobalVariable : LLVM_Attr<"DIGlobalVariable", "di_global_variable",
//===----------------------------------------------------------------------===//
def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block",
- /*traits=*/[], "DIScopeAttr"> {
+ /*traits=*/[], "DILocalScopeAttr"> {
let parameters = (ins
"DIScopeAttr":$scope,
OptionalParameter<"DIFileAttr">:$file,
@@ -579,7 +616,7 @@ def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block",
//===----------------------------------------------------------------------===//
def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file",
- /*traits=*/[], "DIScopeAttr"> {
+ /*traits=*/[], "DILocalScopeAttr"> {
let parameters = (ins
"DIScopeAttr":$scope,
OptionalParameter<"DIFileAttr">:$file,
@@ -637,7 +674,7 @@ def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable",
def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
[LLVM_DIRecursiveTypeAttrInterface],
- "DIScopeAttr"> {
+ "DILocalScopeAttr"> {
let parameters = (ins
// DIRecursiveTypeAttrInterface specific parameters.
OptionalParameter<"DistinctAttr">:$recId,
@@ -1245,7 +1282,8 @@ def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
// TargetFeaturesAttr
//===----------------------------------------------------------------------===//
-def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features">
+def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features",
+ [DLTIQueryInterface]>
{
let summary = "LLVM target features attribute";
@@ -1298,6 +1336,9 @@ def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features">
static constexpr StringLiteral getAttributeName() {
return StringLiteral("target_features");
}
+
+ /// Returns the attribute associated with the key.
+ FailureOr<Attribute> query(DataLayoutEntryKey key);
}];
let assemblyFormat = "`<` `[` (`]`) : ($features^ `]`)? `>`";
@@ -1305,6 +1346,34 @@ def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features">
}
//===----------------------------------------------------------------------===//
+// TargetAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_TargetAttr : LLVM_Attr<"Target", "target",
+ [LLVM_TargetAttrInterface]> {
+ let summary = "LLVM target info: triple, chip, features";
+ let description = [{
+ An attribute to hold LLVM target information, specifying LLVM's target
+ `triple` string, the target `chip` string (i.e. the `cpu` string), and
+ target `features` string as an attribute. The latter is optional.
+
+ Responds to DLTI-queries on the keys:
+ * A query for `"triple"` returns the `StringAttr` for the `triple`.
+ * A query for `"chip"` returns the `StringAttr` for the `chip`/`cpu`.
+ * A query for `"features"` returns the `StringAttr`, if provided.
+ }];
+ let parameters = (ins "StringAttr":$triple,
+ "StringAttr":$chip,
+ OptionalParameter<"TargetFeaturesAttr", "">:$features);
+
+ let assemblyFormat = [{`<` struct($triple, $chip, $features) `>`}];
+
+ let extraClassDeclaration = [{
+ FailureOr<Attribute> query(DataLayoutEntryKey key);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// UndefAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index 3ede857..fafccf3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -15,7 +15,9 @@
#define MLIR_DIALECT_LLVMIR_LLVMATTRS_H_
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include <optional>
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
@@ -89,8 +91,8 @@ public:
// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
// --gen-attr-* and --gen-attrdef-*.
using cconv::CConv;
-using tailcallkind::TailCallKind;
using linkage::Linkage;
+using tailcallkind::TailCallKind;
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index 107bf3e..ab0462f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -27,6 +27,7 @@ def LLVM_Dialect : Dialect {
);
let extraClassDeclaration = [{
+ static StringRef getTargetAttrName() { return "llvm.target"; }
/// Name of the data layout attributes.
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 138170f..60235bc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -14,6 +14,7 @@
#define LLVMIR_INTERFACES
include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let description = [{
@@ -532,4 +533,41 @@ def LLVM_DIRecursiveTypeAttrInterface
];
}
+def LLVM_TargetAttrInterface
+ : AttrInterface<"TargetAttrInterface", [DLTIQueryInterface]> {
+ let description = [{
+ Interface for attributes that describe LLVM targets.
+
+ These attributes should be able to return the specified target `triple`,
+ `chip` and `features`.
+
+ Implementing attributes should provide a `DLTIQueryInterface::query()`
+ implementation which responds to keys `"triple"`, `"chip"` and `"features"`
+ by returning appropriate `StringAttr`s.
+ }];
+ let cppNamespace = "::mlir::LLVM";
+ let methods = [
+ InterfaceMethod<
+ /*description=*/"Returns the target triple identifier.",
+ /*retTy=*/"StringAttr",
+ /*methodName=*/"getTriple",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the target chip (i.e. \"cpu\") identifier.",
+ /*retTy=*/"StringAttr",
+ /*methodName=*/"getChip",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the target features as a TargetFeaturesAttr.",
+ /*retTy=*/"Attribute", // NB: will be a LLVM::TargetFeaturesAttr, though
+ // need to work around a cyclic dependency on
+ // LLVMInterfaces.td and LLVMAttrDefs.td.
+ /*methodName=*/"getFeatures",
+ /*args=*/(ins)
+ >
+ ];
+}
+
#endif // LLVMIR_INTERFACES
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index d38298f..fa2e10c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -87,21 +87,21 @@ class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
class LLVM_CountZerosIntrOp<string func, list<Trait> traits = []> :
LLVM_OneResultIntrOp<func, [], [0],
!listconcat([Pure, SameOperandsAndResultType], traits),
- /*requiresFastmath=*/0,
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[1], /*immArgAttrNames=*/["is_zero_poison"]> {
let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
I1Attr:$is_zero_poison);
}
def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure],
- /*requiresFastmath=*/0,
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[1], /*immArgAttrNames=*/["is_int_min_poison"]> {
let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
I1Attr:$is_int_min_poison);
}
def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
- /*requiresFastmath=*/0,
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
}
@@ -348,15 +348,11 @@ def LLVM_PtrMaskOp
// Memory marker intrinsics.
//
-/// Base operation for lifetime markers. The LLVM intrinsics require the size
-/// operand to be an immediate. In MLIR it is encoded as an attribute.
-class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
- [DeclareOpInterfaceMethods<PromotableOpInterface>],
- /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
- /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
- let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
- let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
+/// Base operation for lifetime markers.
+class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [0],
+ [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
+ let arguments = (ins LLVM_AnyPointer:$ptr);
+ let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";
}
def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start">;
@@ -364,8 +360,8 @@ def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">;
def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
- /*requiresFastmath=*/0, /*immArgPositions=*/[0],
- /*immArgAttrNames=*/["size"]> {
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
+ /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
let results = (outs LLVM_DefaultPointer:$res);
let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
@@ -416,6 +412,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
!gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
true : []),
/*requiresFastmath=*/0,
+ /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[],
/*immArgAttrNames=*/[]> {
dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
@@ -593,7 +590,7 @@ def LLVM_ExpectOp
def LLVM_ExpectWithProbabilityOp
: LLVM_OneResultIntrOp<"expect.with.probability", [], [0],
[Pure, AllTypesMatch<["val", "expected", "res"]>],
- /*requiresFastmath=*/0,
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> {
let arguments = (ins AnySignlessInteger:$val,
AnySignlessInteger:$expected,
@@ -676,22 +673,6 @@ def LLVM_CoroPromiseOp : LLVM_IntrOp<"coro.promise", [], [], [], 1> {
class LLVM_DbgIntrOp<string name, string argName, list<Trait> traits = []>
: LLVM_IntrOp<name, [], [], traits, 0> {
- let llvmBuilder = [{
- // Debug intrinsics without debug locations are invalid.
- if(!builder.getCurrentDebugLocation())
- return success();
- llvm::Module *module = builder.GetInsertBlock()->getModule();
- llvm::LLVMContext &ctx = module->getContext();
- llvm::Function *fn =
- llvm::Intrinsic::getOrInsertDeclaration(module, llvm::Intrinsic::}]
- # !subst(".", "_", name) # [{);
- builder.CreateCall(fn, {
- llvm::MetadataAsValue::get(ctx,
- llvm::ValueAsMetadata::get(moduleTranslation.lookupValue(opInst.getOperand(0)))),
- llvm::MetadataAsValue::get(ctx, moduleTranslation.translateDebugInfo($varInfo)),
- llvm::MetadataAsValue::get(ctx, moduleTranslation.translateExpression($locationExpr)),
- });
- }];
let mlirBuilder = [{
// Add debug intrindic to the list of intrinsics that need to be converted once the
// full function was converted.
@@ -714,6 +695,22 @@ def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr", [
LLVM_DILocalVariableAttr:$varInfo,
DefaultValuedAttr<LLVM_DIExpressionAttr, "{}">:$locationExpr
);
+ let llvmBuilder = [{
+ // Debug records without debug locations are invalid.
+ if(!builder.getCurrentDebugLocation())
+ return success();
+ llvm::DILocalScope *scope = getLocalScopeFromLoc(builder, opInst.getLoc(),
+ moduleTranslation);
+
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ llvm::DIBuilder debugInfoBuilder(*module);
+ debugInfoBuilder.insertDeclare(moduleTranslation.lookupValue(opInst.getOperand(0)),
+ llvm::cast<llvm::DILocalVariable>(
+ moduleTranslation.translateDebugInfo($varInfo)),
+ moduleTranslation.translateExpression($locationExpr),
+ moduleTranslation.translateLoc(opInst.getLoc(), scope),
+ builder.GetInsertPoint());
+ }];
}
def LLVM_DbgValueOp : LLVM_DbgIntrOp<"dbg.value", "value",
@@ -724,22 +721,41 @@ def LLVM_DbgValueOp : LLVM_DbgIntrOp<"dbg.value", "value",
LLVM_DILocalVariableAttr:$varInfo,
DefaultValuedAttr<LLVM_DIExpressionAttr, "{}">:$locationExpr
);
+ let llvmBuilder = [{
+ // Debug records without debug locations are invalid.
+ if(!builder.getCurrentDebugLocation())
+ return success();
+ llvm::DILocalScope *scope = getLocalScopeFromLoc(builder, opInst.getLoc(),
+ moduleTranslation);
+
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ llvm::DIBuilder debugInfoBuilder(*module);
+ debugInfoBuilder.insertDbgValueIntrinsic(
+ moduleTranslation.lookupValue(opInst.getOperand(0)),
+ llvm::cast<llvm::DILocalVariable>(
+ moduleTranslation.translateDebugInfo($varInfo)),
+ moduleTranslation.translateExpression($locationExpr),
+ moduleTranslation.translateLoc(opInst.getLoc(), scope),
+ builder.GetInsertPoint());
+ }];
}
def LLVM_DbgLabelOp : LLVM_IntrOp<"dbg.label", [], [], [], 0> {
let summary = "Relates the program to a debug information label.";
let arguments = (ins LLVM_DILabelAttr:$label);
let llvmBuilder = [{
- // Debug intrinsics without debug locations are invalid.
+ // Debug records without debug locations are invalid.
if(!builder.getCurrentDebugLocation())
return success();
+ llvm::DILocalScope *scope = getLocalScopeFromLoc(builder, opInst.getLoc(),
+ moduleTranslation);
+
llvm::Module *module = builder.GetInsertBlock()->getModule();
- llvm::LLVMContext &ctx = module->getContext();
- llvm::Function *fn =
- llvm::Intrinsic::getOrInsertDeclaration(module, llvm::Intrinsic::dbg_label);
- builder.CreateCall(fn, {
- llvm::MetadataAsValue::get(ctx, moduleTranslation.translateDebugInfo($label))
- });
+ llvm::DIBuilder debugInfoBuilder(*module);
+ debugInfoBuilder.insertLabel(
+ llvm::cast<llvm::DILabel>(moduleTranslation.translateDebugInfo($label)),
+ moduleTranslation.translateLoc(opInst.getLoc(), scope),
+ builder.GetInsertPoint());
}];
let mlirBuilder = [{
DILabelAttr labelAttr = $_label_attr($label);
@@ -829,7 +845,7 @@ class LLVM_VecReductionAccBase<string mnem, Type element>
/*overloadedResults=*/[],
/*overloadedOperands=*/[1],
/*traits=*/[Pure, SameOperandsAndResultElementType],
- /*equiresFastmath=*/1>,
+ /*requiresFastmath=*/1>,
Arguments<(ins element:$start_value,
LLVM_VectorOf<element>:$input,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;
@@ -1073,14 +1089,36 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
}
/// Create a call to Masked Expand Load intrinsic.
-def LLVM_masked_expandload : LLVM_IntrOp<"masked.expandload", [0], [], [], 1> {
- let arguments = (ins LLVM_AnyPointer, LLVM_VectorOf<I1>, LLVM_AnyVector);
+def LLVM_masked_expandload
+ : LLVM_OneResultIntrOp<"masked.expandload", [0], [],
+ /*traits=*/[], /*requiresFastMath=*/0, /*requiresArgAndResultAttrs=*/1,
+ /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
+ dag args = (ins LLVM_AnyPointer:$ptr,
+ LLVM_VectorOf<I1>:$mask,
+ LLVM_AnyVector:$passthru);
+
+ let arguments = !con(args, baseArgs);
+
+ let builders = [
+ OpBuilder<(ins "TypeRange":$resTy, "Value":$ptr, "Value":$mask, "Value":$passthru, CArg<"uint64_t", "1">:$align)>
+ ];
}
/// Create a call to Masked Compress Store intrinsic.
def LLVM_masked_compressstore
- : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0> {
- let arguments = (ins LLVM_AnyVector, LLVM_AnyPointer, LLVM_VectorOf<I1>);
+ : LLVM_ZeroResultIntrOp<"masked.compressstore", [0],
+ /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
+ dag args = (ins LLVM_AnyVector:$value,
+ LLVM_AnyPointer:$ptr,
+ LLVM_VectorOf<I1>:$mask);
+
+ let arguments = !con(args, baseArgs);
+
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask, CArg<"uint64_t", "1">:$align)>
+ ];
}
//
@@ -1159,7 +1197,7 @@ def LLVM_vector_insert
PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
CPred<"!isScalableVectorType($srcvec.getType()) || "
"isScalableVectorType($dstvec.getType())">>],
- /*requiresFastmath=*/0,
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> {
let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec,
I64Attr:$pos);
@@ -1193,7 +1231,7 @@ def LLVM_vector_extract
PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
CPred<"!isScalableVectorType($res.getType()) || "
"isScalableVectorType($srcvec.getType())">>],
- /*requiresFastmath=*/0,
+ /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
/*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> {
let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
let results = (outs LLVM_AnyVector:$res);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a8d7cf2..d6aa958 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -475,11 +475,12 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> overloadedOperands = [],
list<Trait> traits = [],
bit requiresFastmath = 0,
+ bit requiresArgAndResultAttrs = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- requiresFastmath, /*requiresArgAndResultAttrs=*/0,
+ requiresFastmath, requiresArgAndResultAttrs,
/*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3f27f6d..9753dca 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1313,7 +1313,8 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
OptionalAttr<StrAttr>:$section,
OptionalAttr<SymbolRefAttr>:$comdat,
OptionalAttr<DIGlobalVariableExpressionArrayAttr>:$dbg_exprs,
- DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_
+ DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
+ OptionalAttr<ArrayAttr>:$target_specific_attrs
);
let summary = "LLVM dialect global.";
let description = [{
@@ -1411,6 +1412,21 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) { alignment = 32 : i64 } : !llvm.array<8 x f32>
```
+ The `target_specific_attrs` attribute provides a mechanism to preserve
+ target-specific LLVM IR attributes that are not explicitly modeled in the
+ LLVM dialect.
+
+ The attribute is an array containing either string attributes or
+ two-element array attributes of strings. The value of a standalone string
+ attribute is interpreted as the name of an LLVM IR attribute on the global.
+ A two-element array is interpreted as a key-value pair.
+
+ Example:
+
+ ```mlir
+ llvm.mlir.global external @example() {
+ target_specific_attrs = ["value-less-attr", ["int-attr", "4"], ["string-attr", "string"]]} : f64
+ ```
}];
let regions = (region AnyRegion:$initializer);
@@ -1961,7 +1977,6 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<BoolAttr>:$unsafe_fp_math,
OptionalAttr<BoolAttr>:$no_infs_fp_math,
OptionalAttr<BoolAttr>:$no_nans_fp_math,
- OptionalAttr<BoolAttr>:$approx_func_fp_math,
OptionalAttr<BoolAttr>:$no_signed_zeros_fp_math,
OptionalAttr<StrAttr>:$denormal_fp_math,
OptionalAttr<StrAttr>:$denormal_fp_math_f32,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 17561f7..a150649 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -28,6 +28,7 @@ namespace mlir {
class AsmParser;
class AsmPrinter;
+class DataLayout;
namespace LLVM {
class LLVMDialect;
@@ -111,6 +112,15 @@ bool isCompatibleFloatingPointType(Type type);
/// dialect pointers and LLVM dialect scalable vector types.
bool isCompatibleVectorType(Type type);
+/// Returns `true` if the given type is a loadable type compatible with the LLVM
+/// dialect.
+bool isLoadableType(Type type);
+
+/// Returns true if the given type is supported by atomic operations. All
+/// integer, float, and pointer types with a power-of-two bitsize and a minimal
+/// size of 8 bits are supported.
+bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout);
+
/// Returns the element count of any LLVM-compatible vector type.
llvm::ElementCount getVectorNumElements(Type type);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 30df3b7..8537c70 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
+def LLVM_PointerConst : LLVM_PointerInAddressSpace<4>;
def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
@@ -83,6 +84,15 @@ def NVVM_Dialect : Dialect {
/// are grid constants.
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+ /// Get the name of the attribute used to annotate the `.blocksareclusters`
+ /// PTX directive for kernel functions.
+ /// This attribute implies that the grid launch configuration for the
+ /// corresponding kernel function is specifying the number of clusters
+ /// instead of the number of thread blocks. This attribute is only
+ /// allowed for kernel functions and requires nvvm.reqntid and
+ /// nvvm.cluster_dim attributes.
+ static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
+
/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the region at 'regionIndex' on the given operation. Returns failure if
/// the verification failed, success otherwise. This hook may optionally be
@@ -258,6 +268,7 @@ def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster
def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
+def NVVM_GlobalTimerLoOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer.lo">;
//===----------------------------------------------------------------------===//
// envreg registers
@@ -315,16 +326,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
}];
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
+ Variadic<AnyType>:$readWriteArgs,
StrAttr:$ptxCode,
PtxPredicate:$predicate);
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
-
- let assemblyFormat = [{
- $ptxCode `(` $readOnlyArgs `)`
- (`,` `predicate` `=` $predicate^)? attr-dict
- `:` type(operands)
- (`->` type($writeOnlyArgs)^)?
+
+ let assemblyFormat = [{
+ $ptxCode
+ ( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
+ ( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict
+ ( `->` type($writeOnlyArgs)^ )?
}];
let extraClassDefinition = [{
@@ -333,6 +347,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
return std::string(ptxInstStr.data());
}
}];
+
+ let extraClassDeclaration = [{
+ bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
+ }];
}
//===----------------------------------------------------------------------===//
@@ -402,6 +420,74 @@ def NVVM_ReduxOp :
}
//===----------------------------------------------------------------------===//
+// NVVM nanosleep
+//===----------------------------------------------------------------------===//
+
+def NVVM_NanosleepOp : NVVM_Op<"nanosleep">,
+ Arguments<(ins
+ ConfinedAttr<I32Attr, [IntMinValue<1>, IntMaxValue<1000000>]>:$duration)>
+{
+ let summary = "Suspends the thread for a specified duration.";
+
+ let description = [{
+ The op suspends the thread for a sleep duration approximately close to the
+ delay `$duration`, specified in nanoseconds.
+
+ The sleep duration is approximated, but guaranteed to be in the
+ interval [0, 2*t]. The maximum sleep duration is 1 millisecond.
+ The implementation may reduce the sleep duration for individual threads
+ within a warp such that all sleeping threads in the warp wake up together.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-nanosleep)
+ }];
+
+ string llvmBuilder = [{
+ createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_nanosleep,
+ {builder.getInt32($duration)});
+ }];
+ let assemblyFormat = "attr-dict $duration";
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM Performance Monitor events
+//===----------------------------------------------------------------------===//
+
+def NVVM_PMEventOp : NVVM_PTXBuilder_Op<"pmevent">,
+ Arguments<(ins OptionalAttr<I16Attr>:$maskedEventId,
+ OptionalAttr<I32Attr>:$eventId)> {
+ let summary = "Trigger one or more Performance Monitor events.";
+
+ let description = [{
+ Triggers one or more of a fixed number of performance monitor events, with
+ event index or mask specified by immediate operand.
+
+ Without `mask` it triggers a single performance monitor event indexed by
+ immediate operand a, in the range 0..15.
+
+ With `mask` it triggers one or more of the performance monitor events. Each
+ bit in the 16-bit immediate operand controls an event.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-pmevent)
+ }];
+
+ string llvmBuilder = [{
+ llvm::Value *mId = builder.getInt16(* $maskedEventId);
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_pm_event_mask, {mId});
+ }];
+
+ let assemblyFormat = "attr-dict (`id` `=` $eventId^)? (`mask` `=` $maskedEventId^)?";
+
+ let extraClassDeclaration = [{
+ bool hasIntrinsic() { return !getEventId(); }
+ }];
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() { return std::string("pmevent %0;"); }
+ }];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
// NVVM Split arrive/wait barrier
//===----------------------------------------------------------------------===//
@@ -2032,13 +2118,16 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
- Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+ Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
+ MMALayoutAttr:$layout,
+ LdStMatrixShapeAttr:$shape,
+ LdStMatrixEltTypeAttr:$eltType)> {
let summary = "cooperative matrix load";
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
- auto intId = getLdMatrixIntrinsicId($layout, $num);
+ auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType);
$res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
}];
@@ -2215,6 +2304,70 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+// List of modes supported for TMA Load and Prefetch Ops
+def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
+def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
+def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
+
+def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
+ [TMALoadModeTile, TMALoadModeIm2Col,
+ TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
+ TMALoadModeTileGather4]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
+ let summary = "List of Load-Modes supported for TMA Tensor Ops";
+ let description = [{
+ TMA Tensor Ops support the following modes, when copying data from
+ global memory to shared memory (i.e. load):
+
+ Tile Mode: It's the default mode. The source multi-dimensional tensor
+ layout is preserved at the destination.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)
+
+ Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+ The elements in the Bounding Box of the source tensor are rearranged into
+ columns at the destination. In this mode, the tensor has to be at least
+ 3-dimensional. The number of `im2colOffsets` is `dims - 2` where `dims`
+ is the dimension of the tensor.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)
+
+ Im2col_W Mode: This mode is similar to Im2Col mode with the restriction that
+ elements are accessed across the W dimension only. The number of `im2colOffsets`
+ are always two, referred as `wHalo` and `wOffset`.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+ Im2col_W_128 Mode: This mode is similar to Im2Col_W mode with the number of
+ elements accessed across the W dimension is always 128 only.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+ Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D tensor.
+ In gather4 mode, four rows in the source 2D tensor are combined to form a single
+ 2D tensor at the destination. This mode requires five co-ordinates. The first one
+ represents the column-index followed by four row indices.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
+ }];
+
+ let assemblyFormat = "`<` $value `>`";
+}
+
+// List of modes supported for TMA Store and Reduction Ops
+def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMAStoreModeTileScatter4 : I32EnumAttrCase<"TILE_SCATTER4", 2, "tile_scatter4">;
+
+def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
+ [TMAStoreModeTile, TMAStoreModeIm2Col, TMAStoreModeTileScatter4]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
@@ -2341,20 +2494,43 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
}
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
- NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ NVVM_PTXBuilder_Op<"cp.async.bulk.tensor.global.shared.cta",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
AttrSizedOperandSegments]>,
- Arguments<(ins LLVM_AnyPointer:$tmaDescriptor,
- LLVM_PointerShared:$srcMem,
- Variadic<I32>:$coordinates,
- PtxPredicate:$predicate)> {
+ Arguments<(ins LLVM_PointerGeneric:$tmaDescriptor,
+ LLVM_PointerShared:$srcMem,
+ Variadic<I32>:$coordinates,
+ Optional<I64>:$l2CacheHint,
+ DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
+ PtxPredicate:$predicate)> {
+ let description = [{
+ Initiates an asynchronous copy of the tensor data from shared::cta
+ memory to global memory. This Op supports all the store modes specified in
+ `TMAStoreMode`.
+
+ The `l2CacheHint` operand is optional, and it is used to specify cache
+ eviction policy that may be used during the memory access.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
+ }];
+
let assemblyFormat = [{
$tmaDescriptor `,`
$srcMem `,`
`box` `[`$coordinates `]`
- (`,` `predicate` `=` $predicate^)?
- attr-dict `:` type(operands)
+ (`l2_cache_hint` `=` $l2CacheHint^ )?
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type($tmaDescriptor) `,` type($srcMem)
}];
+
+ let extraClassDeclaration = [{
+ bool hasIntrinsic() { return !getPredicate(); }
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }];
+
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int dim = getCoordinates().size();
@@ -2370,6 +2546,12 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
}
}];
let hasVerifier = 1;
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ createIntrinsicCall(builder, id, args);
+ }];
}
//===----------------------------------------------------------------------===//
@@ -2389,15 +2571,25 @@ def PrefetchCacheLevelAttr : EnumAttr<NVVM_Dialect, PrefetchCacheLevel, "prefetc
let assemblyFormat = "$value";
}
-def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
+def NVVM_PrefetchOp : NVVM_Op<"prefetch",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
let summary = "Brings the cache line containing an address into the specified cache level";
let description = [{
- Operand `addr` can be a global, local or generic address pointer. No
- operation is performed if `addr` maps to a `shared` memory location.
+ Prefetches the cache line containing the address given by `addr`. The
+ operand may be a global, local, or generic pointer. When `tensormap` is
+ specified, the operand may instead be a constant or generic pointer. If the
+ address maps to shared memory, the operation has no effect.
+
+ At most one of `cacheLevel` or `tensormap` may be present. The `cacheLevel`
+ attribute selects the target cache level. When combined with `uniform`, the
+ prefetch is performed to the uniform cache, in which case `addr` must be a
+ generic pointer.
+
+ When `tensormap` is used, the line containing `addr` is brought from the
+ constant or parameter state space for later use by `cp.async.bulk.tensor`.
+ If `in_param_space` is specified, the generic pointer is interpreted as
+ referring to the parameter state space.
- The `cacheLevel` attribute specifies the cache level to which the cache line
- containing the specified address is brought.
-
`uniform` can be specified after the `cacheLevel` to indicate that the
prefetch is performed to the specified uniform cache level. If `uniform` is
specified, `addr` must be a generic address pointer and no operation is
@@ -2408,33 +2600,41 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
}];
- let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel,
- UnitAttr:$uniform,
+ let arguments = (ins OptionalAttr<PrefetchCacheLevelAttr>:$cacheLevel,
+ OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority,
AnyTypeOf<[LLVM_PointerGlobal,
LLVM_PointerLocal,
- LLVM_PointerGeneric]>:$addr,
- OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority);
- let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)";
+ LLVM_PointerGeneric,
+ LLVM_PointerConst]>:$addr,
+ PtxPredicate:$predicate,
+ UnitAttr:$tensormap,
+ UnitAttr:$uniform,
+ UnitAttr:$in_param_space);
+ let assemblyFormat = "(`level` `=` $cacheLevel^ (`uniform` $uniform^)? `,`)? (`tensormap` $tensormap^ (`in_param_space` $in_param_space^)? `,`)? (`evict_priority` `=` $evictPriority^ `,`)? $addr (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let hasVerifier = 1;
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op);
- }];
- let llvmBuilder = [{
- auto intId = NVVM::PrefetchOp::getIntrinsicID(op);
- createIntrinsicCall(builder, intId, $addr);
+ static NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ bool hasIntrinsic() { return !getPredicate() || !getTensormap(); }
}];
-}
-
-def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
- let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
+ std::string $cppClass::getPtx() {
+ // Inline PTX is only supported for prefetch tensormap
return std::string("prefetch.tensormap [%0];");
}
}];
+ let llvmBuilder = [{
+ auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op,
+ moduleTranslation, builder);
+
+ if(op.getTensormap())
+ // Overloaded intrinsic
+ createIntrinsicCall(builder, id, args, {args[0]->getType()});
+ else
+ createIntrinsicCall(builder, id, args);
+ }];
}
def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
@@ -2483,23 +2683,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
def NVVM_CpAsyncBulkTensorPrefetchOp :
NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
let arguments = (ins
- LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerGeneric:$tmaDescriptor,
Variadic<I32>:$coordinates,
Variadic<I16>:$im2colOffsets,
+ DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
Optional<I64>:$l2CacheHint);
let description = [{
Initiates an asynchronous prefetch operation on the tensor data from global
- memory to L2 cache.
-
- The Op has two modes:
- 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
- layout is preserved at the destination.
-
- 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
- the elements in the Bounding Box of the source tensor are rearranged into
- columns at the destination. In this mode, the tensor has to be at least
- 3-dimensional.
+ memory to L2 cache. This Op supports all the load modes specified in
+ `TMALoadMode`.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.
@@ -2516,50 +2709,20 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
}];
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
}];
let hasVerifier = 1;
string llvmBuilder = [{
- // Arguments to the intrinsic:
- // tmaDesc, tensorDims, im2colOffsets
- // cache_hint(if applicable) and flag(boolean)
- llvm::SmallVector<llvm::Value *> translatedOperands;
- translatedOperands.push_back($tmaDescriptor);
-
- for (auto v : op.getCoordinates())
- translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
- for (auto v : op.getIm2colOffsets())
- translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
- llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
- auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-
- bool isCacheHint = op.getL2CacheHint() ? true : false;
- translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
- translatedOperands.push_back(builder.getInt1(isCacheHint));
-
- auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
- op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
- createIntrinsicCall(builder, intId, translatedOperands);
+ auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ createIntrinsicCall(builder, id, args);
}];
}
-// List of modes supported for TMA Store and Reduction Ops
-def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
-def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
-
-def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
- [TMAStoreModeTile, TMAStoreModeIm2Col]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
// List of Reduction Ops supported with TMA Store
def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;
@@ -2986,8 +3149,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
let hasVerifier = 1;
let extraClassDeclaration = [{
- void getAsmValues(RewriterBase &rewriter,
- llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+ bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
}];
}
@@ -2995,30 +3157,46 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
// NVVM Griddepcontrol Ops
//===----------------------------------------------------------------------===//
-def NVVM_GriddepcontrolWaitOp : NVVM_IntrOp<"griddepcontrol.wait", [], 0> {
- let assemblyFormat = "attr-dict";
+def GridDepActionWait : I32EnumCase<"wait", 0>;
+def GridDepActionLaunchDependent : I32EnumCase<"launch_dependents", 1>;
+
+def GridDepActionKind : I32Enum<"GridDepActionKind", "Action kind for grid dependency control",
+ [GridDepActionWait, GridDepActionLaunchDependent]> {
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def GridDepActionAttr : EnumAttr<NVVM_Dialect, GridDepActionKind, "grid_dep_action">;
+def NVVM_GriddepcontrolOp : NVVM_Op<"griddepcontrol", []> {
let description = [{
- Causes the executing thread to wait until all prerequisite grids in flight
+ If the $kind attribute is set to `wait`, it causes the
+ executing thread to wait until all prerequisite grids in flight
have completed and all the memory operations from the prerequisite grids
are performed and made visible to the current grid.
+ When the $kind is launch_dependents, it signals that specific dependents
+ the runtime system designated to react to this instruction can be scheduled
+ as soon as all other CTAs in the grid issue the same instruction or have
+ completed.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol)
}];
-}
-
-def NVVM_GriddepcontrolLaunchDependentsOp
- : NVVM_IntrOp<"griddepcontrol.launch.dependents", [], 0> {
- let assemblyFormat = "attr-dict";
- let description = [{
- Signals that specific dependents the runtime system designated to react to
- this instruction can be scheduled as soon as all other CTAs in the grid
- issue the same instruction or have completed.
+ let arguments = (ins GridDepActionAttr:$kind);
+ let assemblyFormat = "$kind attr-dict";
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol)
+ string llvmBuilder = [{
+ llvm::Intrinsic::ID id;
+ switch ($kind) {
+ case NVVM::GridDepActionKind::wait:
+ id = llvm::Intrinsic::nvvm_griddepcontrol_wait;
+ break;
+ case NVVM::GridDepActionKind::launch_dependents:
+ id = llvm::Intrinsic::nvvm_griddepcontrol_launch_dependents;
+ break;
+ }
+ createIntrinsicCall(builder, id);
}];
}
@@ -3027,9 +3205,10 @@ def NVVM_GriddepcontrolLaunchDependentsOp
//===----------------------------------------------------------------------===//
def NVVM_MapaOp: NVVM_Op<"mapa",
- [TypesMatchWith<"`res` and `a` should have the same type",
- "a", "res", "$_self">, NVVMRequiresSM<90>]> {
- let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
+ [InputAddressIsCombinationOf<["a", "res"],
+ [[LLVM_PointerShared, LLVM_PointerSharedCluster], [LLVM_PointerGeneric, LLVM_PointerGeneric]],
+ "Valid address-space check(or mapping) for mapa Op">, NVVMRequiresSM<90>]> {
+ let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
string llvmBuilder = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index a2354e2..9fa3ec1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -93,19 +93,22 @@ class ROCDL_IntrPure1Op<string mnemonic> :
class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits, int numResults,
- int requiresAccessGroup = 0, int requiresAliasAnalysis = 0, list<int> immArgPositions = [],
+ int requiresAccessGroup = 0, int requiresAliasAnalysis = 0,
+ int requiresArgAndResultAttrs = 0,
+ list<int> immArgPositions = [],
list<string> immArgAttrNames = []> :
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
overloadedOperands, traits, numResults, requiresAccessGroup,
- requiresAliasAnalysis, 0, 0, 0, immArgPositions, immArgAttrNames>;
+ requiresAliasAnalysis, 0, requiresArgAndResultAttrs, 0,
+ immArgPositions, immArgAttrNames>;
// Subclass to save typing and ease readibility when there aren't overloaded
// operands or memory accesses.
class ROCDL_ConcreteNonMemIntrOp<string mnemonic, list<Trait> traits,
int numResults, list<int> immArgPositions = [],
list<string> immArgNames = []>
- : ROCDL_IntrOp<mnemonic, [], [], traits, numResults, 0, 0,
+ : ROCDL_IntrOp<mnemonic, [], [], traits, numResults, 0, 0, 0,
immArgPositions, immArgNames>;
//===----------------------------------------------------------------------===//
// ROCDL special register op definitions
@@ -148,8 +151,11 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
//===----------------------------------------------------------------------===//
class ROCDL_MbcntOp<string mnemonic> :
- ROCDL_IntrPure1Op<"mbcnt." # mnemonic>,
- Arguments<(ins I32:$in0, I32:$in1)> {
+ ROCDL_IntrOp<"mbcnt." # mnemonic, [], [], [Pure], 1,
+ 0, 0, /*requiresArgAndResultAttrs=*/1> {
+ dag args = (ins I32:$in0, I32:$in1);
+ let arguments = !con(args, baseArgs);
+ let results = (outs I32:$res);
let assemblyFormat = [{
$in0 `,` $in1 attr-dict `:` `(` type($in0) `,` type($in1) `)` `->` type($res)
}];
@@ -189,6 +195,20 @@ def ROCDL_BallotOp :
let assemblyFormat = "$pred attr-dict `:` type($res)";
}
+def ROCDL_ReadfirstlaneOp : ROCDL_IntrOp<"readfirstlane", [], [0], [AllTypesMatch<["res", "src"]>], 1>,
+ Arguments<(ins LLVM_Type:$src)> {
+ let results = (outs LLVM_Type:$res);
+ let summary = "Get the value in first active lane.";
+
+ let description = [{
+ Returns the value in the lowest active lane of the input operand.
+ }];
+
+ let assemblyFormat = [{
+ $src attr-dict `:` type($res)
+ }];
+}
+
def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res", "src0"]>], 1>,
Arguments<(ins LLVM_Type:$src0,
I32:$src1)> {
@@ -201,7 +221,7 @@ def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res",
let assemblyFormat = [{
$src0 `,` $src1 attr-dict `:` `(` type($src0) `,` type($src1) `)` `->` type($res)
- }];
+ }];
}
//===----------------------------------------------------------------------===//
@@ -501,7 +521,7 @@ def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
//===---------------------------------------------------------------------===//
def ROCDL_LoadToLDSOp :
- ROCDL_IntrOp<"load.to.lds", [], [0], [], 0, 0, 1, [2, 3, 4], ["size", "offset", "aux"]> {
+ ROCDL_IntrOp<"load.to.lds", [], [0], [], 0, 0, 1, 0, [2, 3, 4], ["size", "offset", "aux"]> {
dag args = (ins Arg<LLVM_AnyPointer, "", [MemRead]>:$globalPtr,
Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
I32Attr:$size,
@@ -520,7 +540,7 @@ def ROCDL_LoadToLDSOp :
}
def ROCDL_GlobalLoadLDSOp :
- ROCDL_IntrOp<"global.load.lds", [], [], [], 0, 0, 1, [2, 3, 4], ["size", "offset", "aux"]> {
+ ROCDL_IntrOp<"global.load.lds", [], [], [], 0, 0, 1, 0, [2, 3, 4], ["size", "offset", "aux"]> {
dag args = (ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
I32Attr:$size,
@@ -734,7 +754,7 @@ def ROCDL_RawBufferAtomicUMinOp :
// DPP Update intrinsic
def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
- [AllTypesMatch<["res", "src", "old"]>], 1, 0, 0,
+ [AllTypesMatch<["res", "src", "old"]>], 1, 0, 0, 0,
[2, 3, 4, 5], ["dppCtrl", "rowMask", "bankMask", "boundCtrl"]>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
@@ -746,7 +766,7 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
// PermLaneX16 intrinsic operation
def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
- [AllTypesMatch<["res", "old", "src0"]>, AllTypesMatch<["src1", "src2"]>], 1, 0, 0,
+ [AllTypesMatch<["res", "old", "src0"]>, AllTypesMatch<["src1", "src2"]>], 1, 0, 0, 0,
[4, 5], ["fi", "boundControl"]>,
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
I1Attr:$fi, I1Attr:$boundControl)> {
@@ -760,6 +780,53 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
}];
}
+class ROCDL_ConcretePair<Type elem0, Type elem1> :
+ Type<And<[
+ LLVM_AnyStruct.predicate,
+ SubstLeaves<
+ "$_self",
+ "::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[0]",
+ elem0.predicate>,
+ SubstLeaves<
+ "$_self",
+ "::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[1]",
+ elem1.predicate>
+ ]>,
+ "LLVM dialect-compatible struct of " # elem0.summary # "and" # elem1.summary,
+ "::mlir::LLVM::LLVMStructType">,
+ BuildableType<"::mlir::LLVM::LLVMStructType::getLiteral($_builder.getContext(), "
+ "{" # elem0.builderCall # ", " # elem1.builderCall # "})">;
+
+// Permlane16 swap intrinsic operation
+def ROCDL_Permlane16SwapOp : ROCDL_IntrOp<"permlane16.swap", [], [],
+ [], 1, 0, 0, 0,
+ [2, 3], ["fi", "boundControl"]>,
+ Arguments<(ins I32:$old, I32:$src, I1Attr:$fi, I1Attr:$boundControl)> {
+ let results = (outs ROCDL_ConcretePair<I32, I32>:$res);
+ let assemblyFormat = [{
+ attr-dict $old `,` $src `,` $fi `,` $boundControl `:` `(` type($old) `,` type($src) `)` `->` type($res)
+ }];
+ let description = [{
+ Performs a `permlane16.swap` operation with the given operands, applying the
+ permutation specified by $fi to the provided inputs.
+ }];
+}
+
+// Permlane32 swap intrinsic operation
+def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
+ [], 1, 0, 0, 0,
+ [2, 3], ["fi", "boundControl"]>,
+ Arguments<(ins I32:$old, I32:$src, I1Attr:$fi, I1Attr:$boundControl)> {
+ let results = (outs ROCDL_ConcretePair<I32, I32>:$res);
+ let assemblyFormat = [{
+ attr-dict $old `,` $src `,` $fi `,` $boundControl `:` `(` type($old) `,` type($src) `)` `->` type($res)
+ }];
+ let description = [{
+ Performs a `permlane32.swap` operation with the given operands, applying the
+ permutation specified by $fi to the provided inputs.
+ }];
+}
+
class ROCDL_ConcreteVector<Type elem, int length> :
FixedVectorOfLengthAndType<[length], [elem]>,
BuildableType<
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
index becae6a..04e5a0f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVM)
-add_public_tablegen_target(MLIRLLVMPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRLLVMPassIncGen)
add_mlir_doc(Passes LLVMPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
index f0a486d..18b8174 100644
--- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
@@ -5,6 +5,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Linalg)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Linalg)
-add_public_tablegen_target(MLIRLinalgPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgPassIncGen)
add_mlir_doc(Passes LinalgPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index 386d2f3..d62aced 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -62,14 +62,12 @@ add_mlir_dialect(LinalgOps linalg)
set(LLVM_TARGET_DEFINITIONS LinalgEnums.td)
mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen)
-add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgOpsEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen)
-add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgOpsAttributesIncGen)
add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc)
add_dependencies(LinalgOpsDocGen LinalgOdsGen)
@@ -77,20 +75,17 @@ add_dependencies(LinalgOpsDocGen LinalgOdsGen)
set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgStructuredOpsIncGen)
add_dependencies(MLIRLinalgStructuredOpsIncGen LinalgOdsGen)
-add_dependencies(mlir-headers MLIRLinalgStructuredOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgRelayoutOps.td)
mlir_tablegen(LinalgRelayoutOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgRelayoutOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgRelayoutOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgRelayoutOpsIncGen)
add_dependencies(MLIRLinalgRelayoutOpsIncGen LinalgOdsGen)
-add_dependencies(mlir-headers MLIRLinalgRelayoutOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgInterfaces.td)
mlir_tablegen(LinalgInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRLinalgInterfacesIncGen)
-add_dependencies(mlir-headers MLIRLinalgInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 62c04bb..eb4e381 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -145,8 +145,7 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
-namespace mlir {
-namespace linalg {
+namespace mlir::linalg {
/// Returns the outer shape in the packed domain before applying the
/// transposition.
@@ -155,7 +154,194 @@ template <typename OpTy,
std::is_same_v<OpTy, linalg::UnPackOp>>>
SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
-} // namespace linalg
-} // namespace mlir
+/// Specialization of `linalg.matmul` op that has a transpose map on A
+class MatmulTransposeAOp : public MatmulOp {
+ /// Create an affine map for a transpose-A matmul. Used only in the builders.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
+
+public:
+ using MatmulOp::MatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
+
+ /// Build a transpose A matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static MatmulTransposeAOp create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose A matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static MatmulTransposeAOp create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose A matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static MatmulTransposeAOp create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Checks if the affine map is the expected one for this operation
+ static bool isDefaultIndexingMaps(Attribute attr);
+
+ static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.matmul` op that has a transpose map on B
+class MatmulTransposeBOp : public MatmulOp {
+ /// Create an affine map for a transpose-B matmul. Used only in the builders.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
+
+public:
+ using MatmulOp::MatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
+
+ /// Build a transpose B matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static MatmulTransposeBOp create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose B matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static MatmulTransposeBOp create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose B matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static MatmulTransposeBOp create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Checks if the affine map is the expected one for this operation
+ static bool isDefaultIndexingMaps(Attribute attr);
+
+ static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.batch_matmul` op that has a transpose map on A
+class BatchMatmulTransposeAOp : public BatchMatmulOp {
+ /// Create an affine map for a transpose-A batch_matmul. Used only in the
+ /// builders.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
+
+public:
+ using BatchMatmulOp::BatchMatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
+
+ /// Build a transpose A matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static BatchMatmulTransposeAOp
+ create(OpBuilder &builder, Location location, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose A matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static BatchMatmulTransposeAOp
+ create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose A matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static BatchMatmulTransposeAOp
+ create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Checks if the affine map is the expected one for this operation
+ static bool isDefaultIndexingMaps(Attribute attr);
+
+ static bool classof(Operation *op);
+};
+
+/// Specialization of `linalg.batch_matmul` op that has a transpose map on B
+class BatchMatmulTransposeBOp : public BatchMatmulOp {
+ /// Create an affine map for a transpose-B batch_matmul. Used only in the
+ /// builders.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
+
+public:
+ using BatchMatmulOp::BatchMatmulOp;
+ static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
+
+ /// Build a transpose B matmul.
+ static void build(OpBuilder &builder, OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static BatchMatmulTransposeBOp
+ create(OpBuilder &builder, Location location, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose B matmul with a specific result type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static BatchMatmulTransposeBOp
+ create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Build a transpose B matmul with a specific result type and a cast type.
+ static void build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ static BatchMatmulTransposeBOp
+ create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes = {});
+
+ /// Checks if the affine map is the expected one for this operation
+ static bool isDefaultIndexingMaps(Attribute attr);
+
+ static bool classof(Operation *op);
+};
+
+} // namespace mlir::linalg
#endif // MLIR_DIALECT_LINALG_IR_LINALG_H
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index ba73cfb..9f1e88a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -474,7 +474,7 @@ def LinalgStructuredInterface
int64_t resultIndex =
opOperand->getOperandNumber() - $_op.getNumDpsInputs();
assert(resultIndex >= 0 &&
- resultIndex < this->getOperation()->getNumResults());
+ resultIndex < $_op.getNumDpsInits());
Operation *yieldOp = getBlock()->getTerminator();
return &yieldOp->getOpOperand(resultIndex);
}]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147..9aae1b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1056,152 +1056,6 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: BZp
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: matmul_transpose_a
- cpp_class_name: MatmulTransposeAOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_b
- cpp_class_name: MatmulTransposeBOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
name: mmt4d
cpp_class_name: Mmt4DOp
doc: |-
@@ -1359,146 +1213,6 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_a
- cpp_class_name: BatchMatmulTransposeAOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_b
- cpp_class_name: BatchMatmulTransposeBOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
name: quantized_batch_matmul
cpp_class_name: QuantizedBatchMatmulOp
doc: |-
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ca0cc03..f3674c3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
```
- Shortened print form is available. Applies to simple maps with one
- non-yield operation inside the body.
+ Shortened print form is available for simple maps where the body contains exactly
+ two operations (the payload operation and a yield), the payload operation has
+ the same number of operands as block arguments with operands matching block
+ arguments in order, and the yield operand is the result of the payload operation.
- The example above will be printed as:
+ The example above will be printed using the shortened form as:
```mlir
%add = linalg.map { arith.addf }
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
@@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
}
```
- Shortened print form is available. Applies to simple (not variadic) reduces
- with one non-yield operation inside the body. Applies only if the operation
- takes `%out` as the first argument.
+ Shortened print form is available for simple reduces where the body contains exactly
+ two operations (the payload operation and a yield), the payload operation has the
+ same number of operands as block arguments, the first block argument (init) is the
+ last operand of the payload operation with remaining operands matching remaining
+ block arguments in order, and the yield operand is the result of the payload operation.
- The example above will be printed as:
+ The example above will be printed using the shortened form as:
```mlir
- %reduce = linalg.reduce { arith.addf }
+ %reduce = linalg.reduce { arith.addf }
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
dimensions = [1]
@@ -785,6 +789,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
/// Returns a list of AffineMap with the default matmul indexing charactristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+ /// Returns true if the AffineMap is the default matmul indexing charactristic.
+ static bool isDefaultIndexingMaps(Attribute attr);
+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
@@ -1057,6 +1064,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
/// Returns a list with default AffineMap(s), i.e. without broadcasts and transpositions.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+ /// Returns true if the AffineMap is the default batch matmul indexing charactristic.
+ static bool isDefaultIndexingMaps(Attribute attr);
+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
@@ -1181,6 +1191,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+ /// Returns true if the AffineMap is the default batch reduce matmul indexing charactristic.
+ static bool isDefaultIndexingMaps(Attribute attr);
+
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 373842c..44da2965 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -10,6 +10,88 @@
#define MLIR_DIALECT_LINALG_PASSES
include "mlir/Pass/PassBase.td"
+include "mlir/IR/Constraints.td"
+
+// ------------------ Begin of "form" conversions
+//
+// These conversions allow for the transformation of linalg ops between
+// different forms. Structured ops can be represented in different forms,
+// such as generic ops, category ops, and named ops.
+//
+// The operation tree is as follows:
+// generic category named
+// ---------|-------------|----------
+// generic ---> contract ----> matmul
+// | \-> batch_matmul
+// | \-> batch_reduce_matmul
+// | \-> ...
+// \-> elementwise -> add
+// \-> sub
+// \-> ...
+//
+// Morphisms between representations can happen in the following 6 ways:
+// generic <---> category <---> named
+// \-------------------------/
+//
+// generic subsumes category which subsumes structured named (not softmax,
+// convolutions, etc). The generalization path is guaranteed, the
+// specialization path is not.
+
+def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
+ let summary = "Convert linalg ops between forms";
+
+ let description = [{
+ Convert a linalg op from one representation to another equivalent.
+ For example, a linalg named op `linalg.add` can also be written as an
+ category op `linalg.elementwise`, and can also be re-written as
+ a `linalg.generic`, giving the morphism:
+
+ named-op <--> category_op (elementwise, contraction, ..) <--> generic
+
+ Note that the set of `linalg.generic` subsumes named and category ops
+ and therefore not all `linalg.genric` can be converted to named or
+ category op. Similarly, catgory ops subsume named ops.
+
+ Note:
+ Legacy converters:
+ `--linalg-generalize-named-ops` is the path `named-op --> generic-op`
+ `--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
+ }];
+ let dependentDialects = ["linalg::LinalgDialect"];
+
+ let options = [
+ // Generalization path is guaranteed.
+ Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false",
+ "convert named ops to category op e.g. `linalg.elementwise`">,
+ Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false",
+ "convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
+ Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
+ "convert named ops e.g. `linalg.add` to `linalg.generic`">,
+
+ // Specialization path is not guaranteed.
+ Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
+ "convert linalg.generic to equivalent named ops"> ];
+ // TODOs: `generic-to-category`, `category-to-named`
+}
+
+def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops">,
+ Deprecated<"Use 'linalg-morph-ops' instead."> {
+ let summary = "Convert named ops into generic ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
+def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops">,
+ Deprecated<"Use 'linalg-morph-ops' instead."> {
+ let summary = "Convert generic ops back to named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
+// ------------------ End of "form" conversions
+
+def SimplifyDepthwiseConvPass: Pass<"simplify-depthwise-conv"> {
+ let summary = "Simplify depthwise convolution.";
+ let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+}
def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
let summary = "Convert ElementwiseMappable ops to linalg";
@@ -77,11 +159,6 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
];
}
-def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
- let summary = "Convert from one named linalg op to another.";
- let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
-}
-
def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
let summary = "Inline scalar operands into linalg generic ops";
let dependentDialects = [
@@ -89,16 +166,6 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
];
}
-def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
- let summary = "Convert named ops into generic ops";
- let dependentDialects = ["linalg::LinalgDialect"];
-}
-
-def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
- let summary = "Convert generic ops back to named ops";
- let dependentDialects = ["linalg::LinalgDialect"];
-}
-
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let summary = "Fold transform, broadcast and other ops into elementwise";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
index 4f6b251..71a2689 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -1,17 +1,17 @@
set(LLVM_TARGET_DEFINITIONS LinalgMatchOps.td)
mlir_tablegen(LinalgMatchOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgMatchOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgMatchOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgMatchOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgTransformOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgTransformEnums.td)
mlir_tablegen(LinalgTransformOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LinalgTransformOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRLinalgTransformEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRLinalgTransformEnumsIncGen)
add_mlir_doc(LinalgMatchOps LinalgStructuredMatchOps Dialects/ -gen-op-doc)
add_mlir_doc(LinalgTransformOps LinalgStructuredTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 61ce23f..a19cce4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2348,6 +2348,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.
This transformation supports the following attributes:
+ - `fold_type_extensions_into_contract`: a `UnitAttr` to enable the folding of
+ type extension operations into `vector.contract` to create a mixed precision
+ operation.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
@@ -2368,6 +2371,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
+ UnitAttr:$fold_type_extensions_into_contract,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2381,6 +2385,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
+ CArg<"bool", "false">:$foldTypeExtensionsIntoContract,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a..64d3a24 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -649,7 +649,7 @@ FailureOr<TilingInterface>
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
const PadTilingInterfaceOptions &constOptions,
SmallVector<tensor::PadOp> &padOps,
- PadSizeComputationFunction computePaddingSizeFun =
+ const PadSizeComputationFunction &computePaddingSizeFun =
&computeIndexingMapOpInterfacePaddedShape);
namespace detail {
@@ -1690,7 +1690,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
-/// Requires that all the outer dims of the input linalg::PackOp are 1.
+/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
@@ -1831,6 +1831,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
+/// to equivalent `linalg.elementwise`.
+void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
@@ -1914,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
+/// Patterns to sink extract slice across other operations.
+void populateExtractSliceSinkingPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation);
+
/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
@@ -1958,9 +1967,8 @@ void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
void populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns);
-/// Patterns to convert from one named op to another. These can be seen as
-/// canonicalizations of named ops into another named op.
-void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
+/// Patterns to simplify depthwise convolutions.
+void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors via reassociative reshape ops.
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
index 6e89471..2013f34 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
@@ -5,11 +5,9 @@ add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS MLProgramAttributes.td)
mlir_tablegen(MLProgramAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(MLProgramAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMLProgramAttributesIncGen)
-add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRMLProgramAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS MLProgramTypes.td)
mlir_tablegen(MLProgramTypes.h.inc -gen-typedef-decls)
mlir_tablegen(MLProgramTypes.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRMLProgramTypesIncGen)
-add_dependencies(mlir-headers MLIRMLProgramTypesIncGen)
+add_mlir_dialect_tablegen_target(MLIRMLProgramTypesIncGen)
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
index c5c11f1..5fa6bdc 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram)
-add_public_tablegen_target(MLIRMLProgramPassIncGen)
-add_dependencies(mlir-headers MLIRMLProgramPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRMLProgramPassIncGen)
add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt
index dfec2ea..eeb3ca5 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt
@@ -5,13 +5,13 @@ add_mlir_doc(MPIOps MPI Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS MPIOps.td)
mlir_tablegen(MPIOps.h.inc -gen-op-decls)
mlir_tablegen(MPIOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRMPIOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRMPIOpsIncGen)
# Add MPI types
set(LLVM_TARGET_DEFINITIONS MPITypes.td)
mlir_tablegen(MPITypesGen.h.inc -gen-typedef-decls)
mlir_tablegen(MPITypesGen.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRMPITypesIncGen)
+add_mlir_dialect_tablegen_target(MLIRMPITypesIncGen)
# Add MPI attributes
set(LLVM_TARGET_DEFINITIONS MPI.td)
@@ -19,4 +19,4 @@ mlir_tablegen(MPIEnums.h.inc -gen-enum-decls)
mlir_tablegen(MPIEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(MPIAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(MPIAttrDefs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMPIAttrsIncGen)
+add_mlir_dialect_tablegen_target(MLIRMPIAttrsIncGen)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 5637038..cfd8c4b 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -353,6 +353,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
}
//===----------------------------------------------------------------------===//
+// ClampFOp
+//===----------------------------------------------------------------------===//
+
+def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
+ let summary = "floating point clamping operation";
+ let description = [{
+ The `clampf` operation takes three operands and returns one result, each of
+ these is required to be the same type. Operands must be of floating point type
+ (i.e., scalar, tensor or vector).
+
+ The semantics of the operation are described by:
+ ```
+ clampf(value, min, max) = maxf(minf(value, min), max)
+ ```
+
+ Example:
+
+ ```mlir
+ %d = math.clampf %value to [%min, %max] : f64
+ ```
+ }];
+ let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
+ let assemblyFormat = [{
+ $value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// CopySignOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
index a37f069..c452372 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math)
-add_public_tablegen_target(MLIRMathTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRMathTransformsIncGen)
add_mlir_doc(Passes MathPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index c0fe5d3..b3abbf7 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -23,22 +23,16 @@ class ConversionTarget;
class RewritePatternSet;
class TypeConverter;
-void populateExpandCtlzPattern(RewritePatternSet &patterns);
-void populateExpandTanPattern(RewritePatternSet &patterns);
-void populateExpandSinhPattern(RewritePatternSet &patterns);
-void populateExpandCoshPattern(RewritePatternSet &patterns);
-void populateExpandTanhPattern(RewritePatternSet &patterns);
-void populateExpandAsinhPattern(RewritePatternSet &patterns);
-void populateExpandAcoshPattern(RewritePatternSet &patterns);
-void populateExpandAtanhPattern(RewritePatternSet &patterns);
-void populateExpandFmaFPattern(RewritePatternSet &patterns);
-void populateExpandCeilFPattern(RewritePatternSet &patterns);
-void populateExpandExp2FPattern(RewritePatternSet &patterns);
-void populateExpandPowFPattern(RewritePatternSet &patterns);
-void populateExpandFPowIPattern(RewritePatternSet &patterns);
-void populateExpandRoundFPattern(RewritePatternSet &patterns);
-void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
-void populateExpandRsqrtPattern(RewritePatternSet &patterns);
+namespace math {
+/// Adds patterns to expand math operations into other more fundamental
+/// operations. For example, hyperbolic functions are expanded into expressions
+/// using `exp`. If `opMnemonics` is empty then all available patterns will be
+/// added, otherwise only the patterns corresponding to ops in `opMnemonics`
+/// will be added to the set.
+void populateExpansionPatterns(RewritePatternSet &patterns,
+ ArrayRef<StringRef> opMnemonics = {});
+} // namespace math
+
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index a84c890..4d415ae 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
}
+def MathExpandOpsPass : Pass<"math-expand-ops"> {
+ let summary = "Expand math operations.";
+ let description = [{
+ Expands some math operations into more fundamental operations, allowing them
+ to be subsequently lowered through these. For example, hyperbolic functions
+ are transformed into their expanded form containing only `exp` functions.
+
+ The `ops` parameter can be used to apply only a subset of all the
+ available expansions, these must correspond to the operation mnemonic.
+ For example, `ops=sinh,acosh` will expand only `math.sinh` and
+ `math.acosh` operations. If the list is empty, then all expansions are
+ applied.
+ }];
+ let dependentDialects = ["arith::ArithDialect"];
+ let options = [
+ ListOption<"opMnemonics", "ops", "std::string",
+ "Operations to expand.">
+ ];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 9321089a..d6b7a97 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1214,7 +1214,7 @@ def LoadOp : MemRef_Op<"load",
A set `nontemporal` attribute indicates that this load is not expected to
be reused in the cache. For details, refer to the
- [https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
+ [LLVM load instruction](https://llvm.org/docs/LangRef.html#load-instruction).
An optional `alignment` attribute allows to specify the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
@@ -1947,7 +1947,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
A set `nontemporal` attribute indicates that this store is not expected to
be reused in the cache. For details, refer to the
- [https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
+ [LLVM store instruction](https://llvm.org/docs/LangRef.html#store-instruction).
An optional `alignment` attribute allows to specify the byte alignment of the
store operation. It must be a positive power of 2. The operation must access
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
index 8dbe988..3170faa 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS MemRefTransformOps.td)
mlir_tablegen(MemRefTransformOps.h.inc -gen-op-decls)
mlir_tablegen(MemRefTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRMemRefTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRMemRefTransformOpsIncGen)
add_mlir_doc(MemRefTransformOps MemRefTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
index 6f868f7..e46eae5 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MemRef)
-add_public_tablegen_target(MLIRMemRefPassIncGen)
-add_dependencies(mlir-headers MLIRMemRefPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRMemRefPassIncGen)
add_mlir_doc(Passes MemRefPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
index ecdaae7..7beebae 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
@@ -4,19 +4,19 @@ add_mlir_doc(NVGPUOps NVGPU Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS NVGPUOps.td)
mlir_tablegen(NVGPUOps.h.inc -gen-op-decls)
mlir_tablegen(NVGPUOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRNVGPUOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVGPUOpsIncGen)
set(LLVM_TARGET_DEFINITIONS NVGPU.td)
mlir_tablegen(NVGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVGPUEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRNVGPUEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVGPUEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS NVGPU.td)
mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRNVGPUAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVGPUAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS NVGPUTypes.td)
mlir_tablegen(NVGPUTypeDefs.h.inc -gen-typedef-decls)
mlir_tablegen(NVGPUTypeDefs.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRNVGPUTypesIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVGPUTypesIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
index d75ae3d..f8e7acd 100644
--- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -1,4 +1,4 @@
set(LLVM_TARGET_DEFINITIONS NVGPUTransformOps.td)
mlir_tablegen(NVGPUTransformOps.h.inc -gen-op-decls)
mlir_tablegen(NVGPUTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRNVGPUTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVGPUTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/Transforms/CMakeLists.txt
index 706b66e..12245fe 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/Transforms/CMakeLists.txt
@@ -2,6 +2,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVGPU)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix NVGPU)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix NVGPU)
-add_public_tablegen_target(MLIRNVGPUPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRNVGPUPassIncGen)
add_mlir_doc(Passes NVGPUPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
index 66b1e89..789cd70 100644
--- a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
@@ -2,7 +2,7 @@ add_subdirectory(Transforms)
set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td)
mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC)
-add_public_tablegen_target(acc_common_td)
+add_mlir_dialect_tablegen_target(acc_common_td)
add_mlir_dialect(OpenACCOps acc)
@@ -12,19 +12,16 @@ add_dependencies(OpenACCDialectOpsDocGen acc_common_td)
set(LLVM_TARGET_DEFINITIONS OpenACCOps.td)
mlir_tablegen(OpenACCOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpenACCOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIROpenACCEnumsIncGen)
-add_dependencies(mlir-headers MLIROpenACCEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIROpenACCEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS OpenACCOps.td)
mlir_tablegen(OpenACCOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=acc)
mlir_tablegen(OpenACCOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=acc)
-add_public_tablegen_target(MLIROpenACCAttributesIncGen)
-add_dependencies(mlir-headers MLIROpenACCAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIROpenACCAttributesIncGen)
add_mlir_interface(OpenACCOpsInterfaces)
set(LLVM_TARGET_DEFINITIONS OpenACCTypeInterfaces.td)
mlir_tablegen(OpenACCTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(OpenACCTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIROpenACCTypeInterfacesIncGen)
-add_dependencies(mlir-headers MLIROpenACCTypeInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIROpenACCTypeInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 18d5f2d..cfe73d8 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1302,7 +1302,7 @@ def OpenACC_ReductionRecipeOp
let description = [{
Declares an OpenACC reduction recipe. The operation requires two
- mandatory regions.
+ mandatory regions and one optional region.
1. The initializer region specifies how to initialize the local reduction
value. The region has a first argument that contains the value of the
@@ -1313,6 +1313,8 @@ def OpenACC_ReductionRecipeOp
values of the reduction type into one. It has at least two arguments
and it is expected to `acc.yield` the combined value. Extra arguments
can be added to deal with dynamic arrays.
+ 3. The optional destroy region specifies how to destruct the value when it
+ reaches its end of life. It takes the reduction value as argument.
Example:
@@ -1329,6 +1331,10 @@ def OpenACC_ReductionRecipeOp
// two values into one.
%2 = arith.addi %0, %1 : i64
acc.yield %2 : i64
+ } destroy {
+ ^bb0(%0: i64)
+ // destroy region contains a sequence of operations to destruct the
+ // created copy.
}
// The reduction symbol is then used in the corresponding operation.
@@ -1362,12 +1368,14 @@ def OpenACC_ReductionRecipeOp
OpenACC_ReductionOperatorAttr:$reductionOperator);
let regions = (region AnyRegion:$initRegion,
- AnyRegion:$combinerRegion);
+ AnyRegion:$combinerRegion,
+ AnyRegion:$destroyRegion);
let assemblyFormat = [{
$sym_name `:` $type attr-dict-with-keyword
`reduction_operator` $reductionOperator
`init` $initRegion `combiner` $combinerRegion
+ (`destroy` $destroyRegion^)?
}];
let hasRegionVerifier = 1;
@@ -1536,6 +1544,15 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// Adds a private clause variable to this operation, including its recipe.
void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe);
+ /// Adds a firstprivate clause variable to this operation, including its
+ /// recipe.
+ void addFirstPrivatization(MLIRContext *, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe);
+
+ /// Adds a reduction clause variable to this operation, including its
+ /// recipe.
+ void addReduction(MLIRContext *, mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe);
}];
let assemblyFormat = [{
@@ -1681,6 +1698,14 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
/// Adds a private clause variable to this operation, including its recipe.
void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe);
+ /// Adds a firstprivate clause variable to this operation, including its
+ /// recipe.
+ void addFirstPrivatization(MLIRContext *, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe);
+ /// Adds a reduction clause variable to this operation, including its
+ /// recipe.
+ void addReduction(MLIRContext *, mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe);
}];
let assemblyFormat = [{
@@ -2407,6 +2432,10 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
/// Adds a private clause variable to this operation, including its recipe.
void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op,
mlir::acc::PrivateRecipeOp recipe);
+ /// Adds a reduction clause variable to this operation, including its
+ /// recipe.
+ void addReduction(MLIRContext *, mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe);
}];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
index ddbd583..9459f8b 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC)
-add_public_tablegen_target(MLIROpenACCPassIncGen)
+add_mlir_dialect_tablegen_target(MLIROpenACCPassIncGen)
add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index a65c6b1..b6c8dba 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenMP/OMP.td)
mlir_tablegen(OmpCommon.td --gen-directive-decl --directives-dialect=OpenMP)
-add_public_tablegen_target(omp_common_td)
+add_mlir_dialect_tablegen_target(omp_common_td)
set(LLVM_TARGET_DEFINITIONS OpenMPOps.td)
@@ -25,12 +25,11 @@ mlir_tablegen(OpenMPOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(OpenMPOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=omp)
mlir_tablegen(OpenMPOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=omp)
add_mlir_doc(OpenMPOps OpenMPDialect Dialects/ -gen-dialect-doc -dialect=omp)
-add_public_tablegen_target(MLIROpenMPOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIROpenMPOpsIncGen)
add_dependencies(OpenMPDialectDocGen omp_common_td)
add_mlir_interface(OpenMPOpsInterfaces)
set(LLVM_TARGET_DEFINITIONS OpenMPTypeInterfaces.td)
mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
-add_dependencies(mlir-generic-headers MLIROpenMPTypeInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIROpenMPTypeInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td
index 72ce4c6..c9e6764 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td
@@ -42,10 +42,10 @@ def AtomicControlAttr : OpenMP_Attr<"AtomicControl", "atomic_control"> {
//===----------------------------------------------------------------------===//
def DeclareTargetAttr : OpenMP_Attr<"DeclareTarget", "declaretarget"> {
- let parameters = (ins
- OptionalParameter<"DeclareTargetDeviceTypeAttr">:$device_type,
- OptionalParameter<"DeclareTargetCaptureClauseAttr">:$capture_clause
- );
+ let parameters =
+ (ins OptionalParameter<"DeclareTargetDeviceTypeAttr">:$device_type,
+ OptionalParameter<"DeclareTargetCaptureClauseAttr">:$capture_clause,
+ OptionalParameter<"BoolAttr">:$automap);
let assemblyFormat = "`<` struct(params) `>`";
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index be114ea..2548a8a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2115,4 +2115,121 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
+ [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
+ let summary = "allocate storage on an openmp device for an object of a given type";
+
+ let description = [{
+ Allocates memory on the specified OpenMP device for an object of the given type.
+ Returns an integer value representing the device pointer to the allocated memory.
+ The memory is uninitialized after allocation. Operations must be paired with
+ `omp.target_freemem` to avoid memory leaks.
+
+ * `$device`: The integer ID of the OpenMP device where the memory will be allocated.
+ * `$in_type`: The type of the object for which memory is being allocated.
+ For arrays, this can be a static or dynamic array type.
+ * `$uniq_name`: An optional unique name for the allocated memory.
+ * `$bindc_name`: An optional name used for C interoperability.
+ * `$typeparams`: Runtime type parameters for polymorphic or parameterized types.
+ These are typically integer values that define aspects of a type not fixed at compile time.
+ * `$shape`: Runtime shape operands for dynamic arrays.
+ Each operand is an integer value representing the extent of a specific dimension.
+
+ ```mlir
+ // Allocate a static 3x3 integer vector on device 0
+ %device_0 = arith.constant 0 : i32
+ %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
+ // ... use %ptr_static ...
+ omp.target_freemem %device_0, %ptr_static : i32, i64
+
+ // Allocate a dynamic 2D Fortran array (fir.array) on device 1
+ %device_1 = arith.constant 1 : i32
+ %rows = arith.constant 10 : index
+ %cols = arith.constant 20 : index
+ %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
+ // ... use %ptr_dynamic ...
+ omp.target_freemem %device_1, %ptr_dynamic : i32, i64
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<AnyInteger>:$device,
+ TypeAttr:$in_type,
+ OptionalAttr<StrAttr>:$uniq_name,
+ OptionalAttr<StrAttr>:$bindc_name,
+ Variadic<IntLikeType>:$typeparams,
+ Variadic<IntLikeType>:$shape
+ );
+ let results = (outs I64);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ mlir::Type getAllocatedType();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// TargetFreeMemOp
+//===----------------------------------------------------------------------===//
+
+def TargetFreeMemOp : OpenMP_Op<"target_freemem",
+ [MemoryEffects<[MemFree]>]> {
+ let summary = "free memory on an openmp device";
+
+ let description = [{
+ Deallocates memory on the specified OpenMP device that was previously
+ allocated by an `omp.target_allocmem` operation. After this operation, the
+ deallocated memory is in an undefined state and should not be accessed.
+ It is crucial to ensure that all accesses to the memory region are completed
+ before `omp.target_freemem` is called to avoid undefined behavior.
+
+ * `$device`: The integer ID of the OpenMP device from which the memory will be freed.
+ * `$heapref`: The integer value representing the device pointer to the memory
+ to be deallocated, which was previously returned by `omp.target_allocmem`.
+
+ ```mlir
+ // Example of allocating and freeing memory on an OpenMP device
+ %device_id = arith.constant 0 : i32
+ %allocated_ptr = omp.target_allocmem %device_id : i32, vector<3x3xi32>
+ // ... operations using %allocated_ptr on the device ...
+ omp.target_freemem %device_id, %allocated_ptr : i32, i64
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<AnyInteger, "", [MemFree]>:$device,
+ Arg<I64, "", [MemFree]>:$heapref
+ );
+ let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
+}
+
+//===----------------------------------------------------------------------===//
+// workdistribute Construct
+//===----------------------------------------------------------------------===//
+
+def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+ let summary = "workdistribute directive";
+ let description = [{
+ workdistribute divides execution of the enclosed structured block into
+ separate units of work, each executed only once by each
+ initial thread in the league.
+ ```
+ !$omp target teams
+ !$omp workdistribute
+ y = a * x + y
+ !$omp end workdistribute
+ !$omp end target teams
+ ```
+ }];
+ let regions = (region AnyRegion:$region);
+ let hasVerifier = 1;
+ let assemblyFormat = "$region attr-dict";
+}
+
#endif // OPENMP_OPS
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0dc385b..d471e6c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -329,14 +329,16 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
/*retTy=*/"void",
/*methodName=*/"setDeclareTarget",
(ins "mlir::omp::DeclareTargetDeviceType":$deviceType,
- "mlir::omp::DeclareTargetCaptureClause":$captureClause), [{}], [{
+ "mlir::omp::DeclareTargetCaptureClause":$captureClause,
+ "bool":$automap), [{}], [{
$_op->setAttr("omp.declare_target",
mlir::omp::DeclareTargetAttr::get(
$_op->getContext(),
mlir::omp::DeclareTargetDeviceTypeAttr::get(
$_op->getContext(), deviceType),
mlir::omp::DeclareTargetCaptureClauseAttr::get(
- $_op->getContext(), captureClause)));
+ $_op->getContext(), captureClause),
+ mlir::BoolAttr::get($_op->getContext(), automap)));
}]>,
InterfaceMethod<
/*description=*/[{
@@ -374,6 +376,19 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getCaptureClause().getValue();
return {};
+ }]>,
+ InterfaceMethod<
+ /*description=*/[{
+ Return true if the DeclareTarget attribute has the AUTOMAP modifier.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"getDeclareTargetAutomap",
+ (ins), [{}], [{
+ if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
+ if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
+ if (auto autoVal = dAttr.getAutomap())
+ return autoVal.getValue();
+ return false;
}]>
];
}
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
index 7c94d52..388d735 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt
@@ -4,16 +4,14 @@ add_mlir_doc(PtrOps PtrOps Dialects/ -gen-dialect-doc -dialect=ptr)
set(LLVM_TARGET_DEFINITIONS PtrOps.td)
mlir_tablegen(PtrOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ptr)
mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr)
-add_public_tablegen_target(MLIRPtrOpsAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRPtrOpsAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td)
-mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls)
-mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs)
-add_public_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS PtrOps.td)
mlir_tablegen(PtrOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(PtrOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRPtrOpsEnumsGen)
+add_mlir_dialect_tablegen_target(MLIRPtrOpsEnumsGen)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h
index a046755..4d65c8d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h
@@ -17,16 +17,18 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
+#include <functional>
+#include <optional>
+
namespace mlir {
class Operation;
+class DataLayout;
namespace ptr {
-enum class AtomicBinOp : uint64_t;
-enum class AtomicOrdering : uint64_t;
+enum class AtomicBinOp : uint32_t;
+enum class AtomicOrdering : uint32_t;
} // namespace ptr
} // namespace mlir
#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc"
-#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc"
-
#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
index 54efeb0..5231231 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td
@@ -42,7 +42,8 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
/*methodName=*/ "isValidLoad",
/*args=*/ (ins "::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$ordering,
- "::mlir::IntegerAttr":$alignment,
+ "std::optional<int64_t>":$alignment,
+ "const ::mlir::DataLayout *":$dataLayout,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
@@ -57,7 +58,8 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
/*methodName=*/ "isValidStore",
/*args=*/ (ins "::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$ordering,
- "::mlir::IntegerAttr":$alignment,
+ "std::optional<int64_t>":$alignment,
+ "const ::mlir::DataLayout *":$dataLayout,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
@@ -73,7 +75,8 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
/*args=*/ (ins "::mlir::ptr::AtomicBinOp":$op,
"::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$ordering,
- "::mlir::IntegerAttr":$alignment,
+ "std::optional<int64_t>":$alignment,
+ "const ::mlir::DataLayout *":$dataLayout,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
@@ -90,7 +93,8 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
/*args=*/ (ins "::mlir::Type":$type,
"::mlir::ptr::AtomicOrdering":$successOrdering,
"::mlir::ptr::AtomicOrdering":$failureOrdering,
- "::mlir::IntegerAttr":$alignment,
+ "std::optional<int64_t>":$alignment,
+ "const ::mlir::DataLayout *":$dataLayout,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>,
InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
index dc0a3ff..bb01cea 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h
@@ -19,10 +19,9 @@
#include "llvm/Support/TypeSize.h"
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
+#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
-#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"
-
#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h
new file mode 100644
index 0000000..2e98df8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h
@@ -0,0 +1,21 @@
+//===- PtrEnums.h - `ptr` dialect enums -------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the `ptr` dialect enums.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PTR_IR_PTRENUMS_H
+#define MLIR_DIALECT_PTR_IR_PTRENUMS_H
+
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"
+
+#endif // MLIR_DIALECT_PTR_IR_PTRENUMS_H
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
index cc556c6..c169f48 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
@@ -15,25 +15,25 @@ include "mlir/IR/EnumAttr.td"
// Atomic binary op enum attribute.
//===----------------------------------------------------------------------===//
-def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">;
-def AtomicBinOpAdd : I64EnumAttrCase<"add", 1, "add">;
-def AtomicBinOpSub : I64EnumAttrCase<"sub", 2, "sub">;
-def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3, "_and">;
-def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">;
-def AtomicBinOpOr : I64EnumAttrCase<"_or", 5, "_or">;
-def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6, "_xor">;
-def AtomicBinOpMax : I64EnumAttrCase<"max", 7, "max">;
-def AtomicBinOpMin : I64EnumAttrCase<"min", 8, "min">;
-def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">;
-def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">;
-def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">;
-def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">;
-def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">;
-def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">;
-def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">;
-def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">;
+def AtomicBinOpXchg : I32EnumCase<"xchg", 0, "xchg">;
+def AtomicBinOpAdd : I32EnumCase<"add", 1, "add">;
+def AtomicBinOpSub : I32EnumCase<"sub", 2, "sub">;
+def AtomicBinOpAnd : I32EnumCase<"_and", 3, "_and">;
+def AtomicBinOpNand : I32EnumCase<"nand", 4, "nand">;
+def AtomicBinOpOr : I32EnumCase<"_or", 5, "_or">;
+def AtomicBinOpXor : I32EnumCase<"_xor", 6, "_xor">;
+def AtomicBinOpMax : I32EnumCase<"max", 7, "max">;
+def AtomicBinOpMin : I32EnumCase<"min", 8, "min">;
+def AtomicBinOpUMax : I32EnumCase<"umax", 9, "umax">;
+def AtomicBinOpUMin : I32EnumCase<"umin", 10, "umin">;
+def AtomicBinOpFAdd : I32EnumCase<"fadd", 11, "fadd">;
+def AtomicBinOpFSub : I32EnumCase<"fsub", 12, "fsub">;
+def AtomicBinOpFMax : I32EnumCase<"fmax", 13, "fmax">;
+def AtomicBinOpFMin : I32EnumCase<"fmin", 14, "fmin">;
+def AtomicBinOpUIncWrap : I32EnumCase<"uinc_wrap", 15, "uinc_wrap">;
+def AtomicBinOpUDecWrap : I32EnumCase<"udec_wrap", 16, "udec_wrap">;
-def AtomicBinOp : I64EnumAttr<
+def AtomicBinOp : I32Enum<
"AtomicBinOp",
"ptr.atomicrmw binary operations",
[AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
@@ -48,15 +48,15 @@ def AtomicBinOp : I64EnumAttr<
// Atomic ordering enum attribute.
//===----------------------------------------------------------------------===//
-def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">;
-def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">;
-def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">;
-def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 3, "acquire">;
-def AtomicOrderingRelease : I64EnumAttrCase<"release", 4, "release">;
-def AtomicOrderingAcqRel : I64EnumAttrCase<"acq_rel", 5, "acq_rel">;
-def AtomicOrderingSeqCst : I64EnumAttrCase<"seq_cst", 6, "seq_cst">;
+def AtomicOrderingNotAtomic : I32EnumCase<"not_atomic", 0, "not_atomic">;
+def AtomicOrderingUnordered : I32EnumCase<"unordered", 1, "unordered">;
+def AtomicOrderingMonotonic : I32EnumCase<"monotonic", 2, "monotonic">;
+def AtomicOrderingAcquire : I32EnumCase<"acquire", 3, "acquire">;
+def AtomicOrderingRelease : I32EnumCase<"release", 4, "release">;
+def AtomicOrderingAcqRel : I32EnumCase<"acq_rel", 5, "acq_rel">;
+def AtomicOrderingSeqCst : I32EnumCase<"seq_cst", 6, "seq_cst">;
-def AtomicOrdering : I64EnumAttr<
+def AtomicOrdering : I32Enum<
"AtomicOrdering",
"Atomic ordering for LLVM's memory model",
[AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
@@ -66,6 +66,8 @@ def AtomicOrdering : I64EnumAttr<
let cppNamespace = "::mlir::ptr";
}
+def AtomicOrderingProp : EnumProp<AtomicOrdering>;
+
//===----------------------------------------------------------------------===//
// Ptr add flags enum properties.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 440f6e5..1c88efc 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -119,6 +119,133 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}
//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+def AlignmentProp : OptionalProp<I64Prop>;
+
+def Ptr_LoadOp : Pointer_Op<"load", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+ ]> {
+ let description = [{
+ The `load` operation is used to read from memory. A load may be marked as
+ atomic, volatile, and/or nontemporal.
+
+ An atomic load only supports a limited set of value types, and requires
+ an explicit alignment.
+
+ Examples:
+ ```mlir
+ // A volatile load of a float variable.
+ %0 = ptr.load volatile %ptr : !ptr.ptr -> f32
+
+ // A nontemporal load of a float variable.
+ %0 = ptr.load %ptr nontemporal : !ptr.ptr -> f32
+
+ // An atomic load of an integer variable.
+ %0 = ptr.load %ptr atomic monotonic alignment = 8 : !ptr.ptr -> i64
+ ```
+
+ See the following link for more details on the meaning of `alignment`,
+ `volatile_`, `nontemporal`, `invariant`, `invariant_group`, `ordering`,
+ and `syncscope`:
+ https://llvm.org/docs/LangRef.html#load-instruction
+ }];
+ let arguments = (ins Ptr_PtrType:$ptr,
+ AlignmentProp:$alignment,
+ UnitProp:$volatile_,
+ UnitProp:$nontemporal,
+ UnitProp:$invariant,
+ UnitProp:$invariantGroup,
+ DefaultValuedProp<
+ AtomicOrderingProp,
+ "AtomicOrdering::not_atomic">:$ordering,
+ OptionalAttr<StrAttr>:$syncscope);
+ let results = (outs AnyType:$value);
+ let assemblyFormat = [{
+ (`volatile` $volatile_^)? $ptr
+ (`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
+ oilist(
+ `nontemporal` $nontemporal |
+ `invariant` $invariant |
+ `invariant_group` $invariantGroup |
+ `alignment` `=` $alignment
+ )
+ attr-dict `:` qualified(type($ptr)) `->` type($value)
+ }];
+ let builders = [
+ OpBuilder<(ins "Type":$type, "Value":$ptr,
+ CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
+ CArg<"bool", "false">:$isNonTemporal, CArg<"bool", "false">:$isInvariant,
+ CArg<"bool", "false">:$isInvariantGroup,
+ CArg<"AtomicOrdering", "AtomicOrdering::not_atomic">:$ordering,
+ CArg<"StringRef", "StringRef()">:$syncscope)>
+ ];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_StoreOp : Pointer_Op<"store", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+ ]> {
+ let description = [{
+ The `store` operation is used to write to memory. A store may be marked as
+ atomic, volatile, and/or nontemporal.
+
+ An atomic store only supports a limited set of value types, and requires
+ an explicit alignment.
+
+ Examples:
+ ```mlir
+ // A volatile store of a float variable.
+ ptr.store volatile %val, %ptr : f32, !ptr.ptr
+
+ // A nontemporal store of a float variable.
+ ptr.store %val, %ptr nontemporal : f32, !ptr.ptr
+
+ // An atomic store of an integer variable.
+ ptr.store %val, %ptr atomic monotonic alignment = 8: i64, !ptr.ptr
+ ```
+
+ See the following link for more details on the meaning of `alignment`,
+ `volatile_`, `nontemporal`, `invariant_group`, `ordering`, and `syncscope`:
+ https://llvm.org/docs/LangRef.html#store-instruction
+ }];
+ let arguments = (ins AnyType:$value,
+ Ptr_PtrType:$ptr,
+ AlignmentProp:$alignment,
+ UnitProp:$volatile_,
+ UnitProp:$nontemporal,
+ UnitProp:$invariantGroup,
+ DefaultValuedProp<
+ AtomicOrderingProp,
+ "AtomicOrdering::not_atomic">:$ordering,
+ OptionalAttr<StrAttr>:$syncscope);
+ let assemblyFormat = [{
+ (`volatile` $volatile_^)? $value `,` $ptr
+ (`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
+ oilist(
+ `nontemporal` $nontemporal |
+ `invariant_group` $invariantGroup |
+ `alignment` `=` $alignment
+ )
+ attr-dict `:` type($value) `,` qualified(type($ptr))
+ }];
+ let builders = [
+ OpBuilder<(ins "Value":$value, "Value":$ptr,
+ CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
+ CArg<"bool", "false">:$isNonTemporal,
+ CArg<"bool", "false">:$isInvariantGroup,
+ CArg<"AtomicOrdering", "AtomicOrdering::not_atomic">:$ordering,
+ CArg<"StringRef", "StringRef()">:$syncscope)>
+ ];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
index c08f399..120ddaf 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt
@@ -3,4 +3,4 @@ add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td)
mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant")
-add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen)
+add_mlir_dialect_tablegen_target(MLIRQuantDialectBytecodeIncGen)
diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
index 30f7c16..a920f84 100644
--- a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
-add_public_tablegen_target(MLIRQuantTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRQuantTransformsIncGen)
add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
index 1be5f91..f6dfc0e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
@@ -6,5 +6,4 @@ mlir_tablegen(DeviceMappingAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(DeviceMappingAttrInterface.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(DeviceMappingAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(DeviceMappingAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRDeviceMappingInterfacesIncGen)
-add_dependencies(mlir-generic-headers MLIRDeviceMappingInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIRDeviceMappingInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0c1c15b..88df541 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -169,9 +169,13 @@ def ForOp : SCF_Op<"for",
region capturing the loop body. The induction variable is represented as an
argument of this region. This SSA value is a signless integer or index.
The step is a value of same type but required to be positive, the lower and
- upper bounds can be also negative or zero. The lower and upper bounds specify
- a half-open range: the iteration is executed iff the signed comparison of induction
- variable value is less than the upper bound and bigger or equal to the lower bound.
+ upper bounds can be also negative or zero. The lower and upper bounds
+ specify a half-open range: the iteration is executed iff the comparison of
+ induction variable value is less than the upper bound and bigger or equal
+ to the lower bound.
+
+ By default, the integer comparison is signed. If the `unsignedCmp` unit
+ attribute is specified, the integer comparison is unsigned.
The body region must contain exactly one block that terminates with
`scf.yield`. Calling ForOp::build will create such a region and insert
@@ -184,8 +188,8 @@ def ForOp : SCF_Op<"for",
... // body
}
...
- // Integer case.
- scf.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
+ // Unsigned integer case.
+ scf.for unsigned %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
... // body
}
```
@@ -258,7 +262,8 @@ def ForOp : SCF_Op<"for",
let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
AnySignlessIntegerOrIndex:$upperBound,
AnySignlessIntegerOrIndex:$step,
- Variadic<AnyType>:$initArgs);
+ Variadic<AnyType>:$initArgs,
+ UnitAttr:$unsignedCmp);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -266,7 +271,8 @@ def ForOp : SCF_Op<"for",
let builders = [OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound,
"Value":$step, CArg<"ValueRange", "{}">:$initArgs,
CArg<"function_ref<void(OpBuilder &, Location, Value, ValueRange)>",
- "nullptr">)>];
+ "nullptr">,
+ CArg<"bool", "false">:$unsignedCmp)>];
let extraClassDeclaration = [{
using BodyBuilderFn =
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt
index 9095b1f..97d7b04 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS SCFTransformOps.td)
mlir_tablegen(SCFTransformOps.h.inc -gen-op-decls)
mlir_tablegen(SCFTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRSCFTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRSCFTransformOpsIncGen)
add_mlir_doc(SCFTransformOps SCFLoopTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/Transforms/CMakeLists.txt
index 1192bad..c1fda9e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name SCF)
-add_public_tablegen_target(MLIRSCFPassIncGen)
-add_dependencies(mlir-headers MLIRSCFPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRSCFPassIncGen)
add_mlir_doc(Passes SCFPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index e620067..ecd829e 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -213,6 +213,14 @@ scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
scf::ForallOp forallOp);
+/// Check if the provided loops are perfectly nested for-loops. Perfect nesting
+/// means:
+/// 1. All loops are scf.for operations
+/// 2. Each outer loop's region iter args match the inner loop's init args
+/// 3. Each outer loop's yields match the inner loop's results
+/// 4. Each region iter arg and result has exactly one use
+bool isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt
index 1455551..9d50bc7 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt
@@ -4,12 +4,10 @@ set(LLVM_TARGET_DEFINITIONS SMT.td)
mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRSMTAttrIncGen)
-add_dependencies(mlir-headers MLIRSMTAttrIncGen)
+add_mlir_dialect_tablegen_target(MLIRSMTAttrIncGen)
mlir_tablegen(SMTEnums.h.inc -gen-enum-decls)
mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRSMTEnumsIncGen)
-add_dependencies(mlir-headers MLIRSMTEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRSMTEnumsIncGen)
add_mlir_doc(SMT SMT Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt
index de3148d..cbdc0b2c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt
@@ -4,35 +4,29 @@ add_mlir_doc(SPIRVOps SPIRVOps Dialects/ -gen-op-doc)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
-add_dependencies(mlir-headers MLIRSPIRVEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnumAvailability.h.inc -gen-spirv-enum-avail-decls)
mlir_tablegen(SPIRVEnumAvailability.cpp.inc -gen-spirv-enum-avail-defs)
mlir_tablegen(SPIRVCapabilityImplication.inc -gen-spirv-capability-implication)
-add_public_tablegen_target(MLIRSPIRVEnumAvailabilityIncGen)
-add_dependencies(mlir-headers MLIRSPIRVEnumAvailabilityIncGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVEnumAvailabilityIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls)
mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs)
mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls)
-add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen)
-add_dependencies(mlir-headers MLIRSPIRVAvailabilityIncGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVAvailabilityIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
-add_public_tablegen_target(MLIRSPIRVSerializationGen)
-add_dependencies(mlir-headers MLIRSPIRVSerializationGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVSerializationGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVAttrUtils.inc -gen-spirv-attr-utils)
-add_public_tablegen_target(MLIRSPIRVAttrUtilsGen)
-add_dependencies(mlir-headers MLIRSPIRVAttrUtilsGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVAttrUtilsGen)
set(LLVM_TARGET_DEFINITIONS SPIRVAttributes.td)
mlir_tablegen(SPIRVAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(SPIRVAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRSPIRVAttributeIncGen)
-add_dependencies(mlir-headers MLIRSPIRVAttributeIncGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVAttributeIncGen)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad5017..6253601 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
let arguments = (ins
SPIRV_AnyPtr:$ptr,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
SPIRV_AnyPtr:$ptr,
SPIRV_Type:$value,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs);
diff --git a/mlir/include/mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h b/mlir/include/mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h
index ca5a1dc..0c70d4f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h
@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_SPIRV_IMAGE_INTERFACES_H_
#define MLIR_DIALECT_SPIRV_IMAGE_INTERFACES_H_
+#include "mlir/IR/OpDefinition.h"
+
#include "mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h.inc"
#endif // MLIR_DIALECT_SPIRV_IMAGE_INTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/Transforms/CMakeLists.txt
index 0459420..d0990c7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name SPIRV)
-add_public_tablegen_target(MLIRSPIRVPassIncGen)
-add_dependencies(mlir-headers MLIRSPIRVPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRSPIRVPassIncGen)
add_mlir_doc(Passes SPIRVPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
index 83a500d..991bd68 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shape)
-add_public_tablegen_target(MLIRShapeTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRShapeTransformsIncGen)
add_mlir_doc(Passes ShapePasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
index a2495af..4432e01 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
@@ -21,5 +21,4 @@ set(LLVM_TARGET_DEFINITIONS ShardOps.td)
mlir_tablegen(ShardOps.h.inc -gen-op-decls)
mlir_tablegen(ShardOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRShardIncGen)
-add_dependencies(mlir-headers MLIRShardIncGen)
+add_mlir_dialect_tablegen_target(MLIRShardIncGen)
diff --git a/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
index b3a44f3..29ce6e4 100644
--- a/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
@@ -1,4 +1,4 @@
set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRShardingInterfaceIncGen)
+add_mlir_dialect_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
index 9e2c8d0..840ff9e 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shard)
-add_public_tablegen_target(MLIRShardPassIncGen)
-add_dependencies(mlir-headers MLIRShardPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRShardPassIncGen)
add_mlir_doc(Passes ShardPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
index 54ad949..12048b4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -6,15 +6,14 @@ mlir_tablegen(SparseTensorAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(SparseTensorAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(SparseTensorAttrEnums.h.inc -gen-enum-decls)
mlir_tablegen(SparseTensorAttrEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRSparseTensorAttrDefsIncGen)
+add_mlir_dialect_tablegen_target(MLIRSparseTensorAttrDefsIncGen)
set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRSparseTensorTypesIncGen)
+add_mlir_dialect_tablegen_target(MLIRSparseTensorTypesIncGen)
set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
-add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRSparseTensorInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt
index dc29a0b..9871d7b03 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt
@@ -1,4 +1,4 @@
set(LLVM_TARGET_DEFINITIONS SparseTensorTransformOps.td)
mlir_tablegen(SparseTensorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(SparseTensorTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRSparseTensorTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRSparseTensorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt
index a24a947..b7dc667 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -2,5 +2,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name SparseTensor)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix SparseTensor)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix SparseTensor)
-add_public_tablegen_target(MLIRSparseTensorPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRSparseTensorPassIncGen)
add_mlir_doc(Passes SparseTensorPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Tensor/TransformOps/CMakeLists.txt
index bb9f703..eccf47f 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS TensorTransformOps.td)
mlir_tablegen(TensorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(TensorTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRTensorTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTensorTransformOpsIncGen)
add_mlir_doc(TensorTransformOps TensorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tensor/Transforms/CMakeLists.txt
index b312cee..5c9c9b2 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Tensor)
-add_public_tablegen_target(MLIRTensorTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTensorTransformsIncGen)
add_mlir_doc(Passes TensorPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9..3e4da94 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,32 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
ValueRange independencies);
+/// Computes the offsets, sizes, and strides needed to build a collapsed
+/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
+///
+/// This fails when the specified collapse cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides);
+
+/// Computes the offsets, sizes, and strides needed to build an expanded
+/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
+/// `expandedShape`.
+///
+/// This fails when the specified expansion cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 0a855d7..f533b29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -7,19 +7,19 @@ mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
-add_public_tablegen_target(MLIRTosaAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRTosaAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
-add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
+add_mlir_dialect_tablegen_target(MLIRTosaDialectBytecodeIncGen)
set(LLVM_TARGET_DEFINITIONS TosaOpBase.td)
mlir_tablegen(TosaEnums.h.inc -gen-enum-decls)
mlir_tablegen(TosaEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRTosaEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTosaEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaAvailability.h.inc -gen-avail-interface-decls)
mlir_tablegen(TosaAvailability.cpp.inc -gen-avail-interface-defs)
mlir_tablegen(TosaOpAvailabilityImpl.inc -gen-tosa-avail-impls)
-add_public_tablegen_target(MLIRTosaAvailabilityIncGen)
+add_mlir_dialect_tablegen_target(MLIRTosaAvailabilityIncGen)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8a..115a11b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -382,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
}
//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
+def Tosa_RESIZE_NEAREST_NEIGHBOR : I32EnumAttrCase<"NEAREST_NEIGHBOR", 1>;
+def Tosa_RESIZE_BILINEAR : I32EnumAttrCase<"BILINEAR", 2>;
+
+def Tosa_ResizeModeAttr
+ : Tosa_I32EnumAttr<"ResizeMode", "Supported resize/upsampling strategies", "resize_mode",
+ [Tosa_RESIZE_NEAREST_NEIGHBOR, Tosa_RESIZE_BILINEAR]>;
+
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationModeAttr
+ : Tosa_I32EnumAttr<"NanPropagationMode", "Supported NaN propagation strategies", "nan_mode",
+ [Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;
+
+def Tosa_ROUNDING_SINGLE_ROUND : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingModeAttr
+ : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
+ [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
+
+
+//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
@@ -444,10 +472,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
ResultsBroadcastableShape,
TosaElementwiseOperator,
SameOperandsAndResultRank,
- Pure])> {
- let assemblyFormat =
- "operands attr-dict `:` functional-type(operands, results)";
-}
+ Pure])> {}
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
@@ -455,22 +480,18 @@ class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
SameOperandsAndResultElementType])> {}
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
- : Tosa_Op<mnemonic, !listconcat(traits, [InferTensorTypeAdaptor, Pure])> {
- let assemblyFormat =
- "operands attr-dict `:` functional-type(operands, results)";
-}
+ : Tosa_Op<mnemonic, !listconcat(traits, [InferTensorTypeAdaptor, Pure])> {}
class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
- : Tosa_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpAdaptor, Pure])> {
- let assemblyFormat =
- "operands attr-dict `:` functional-type(operands, results)";
-}
+ : Tosa_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpAdaptor, Pure])> {}
// The "SameVariadicOperandSize" trait allows us to pass optional arguments
// for multiple zero points in convolution ops.
class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
: Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
[SameVariadicOperandSize])> {
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
#endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2088955..953e7c3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let arguments = (ins
Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -57,6 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -109,6 +110,9 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -299,6 +303,9 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
}];
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -336,6 +343,9 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
let builders = [Tosa_MatMulOpQuantInfoBuilder];
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -357,7 +367,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -371,6 +381,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
let hasCanonicalizer = 1;
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -418,6 +429,9 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
}];
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -487,7 +501,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
Tosa_Tensor:$input,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -501,6 +515,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
let hasCanonicalizer = 1;
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -560,6 +575,8 @@ def Tosa_SigmoidOp : Tosa_ElementwiseUnaryOp<"sigmoid"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -589,6 +606,8 @@ def Tosa_TanhOp : Tosa_ElementwiseUnaryOp<"tanh"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -633,6 +652,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -662,6 +683,8 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
Profile<[Tosa_PRO_INT]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -690,6 +713,8 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
Profile<[Tosa_PRO_INT]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -718,6 +743,8 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
Profile<[Tosa_PRO_INT]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -746,6 +773,8 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
Profile<[Tosa_PRO_INT]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -777,6 +806,8 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"intdiv", [SameOperandsAndResultElementTy
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -805,6 +836,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -833,6 +866,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -861,6 +896,8 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -889,6 +926,8 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -917,6 +956,8 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -935,7 +976,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -946,6 +987,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -964,7 +1006,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -975,6 +1017,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -1041,6 +1085,8 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1069,6 +1115,8 @@ def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1113,6 +1161,9 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
}];
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1149,6 +1200,8 @@ def Tosa_AbsOp : Tosa_ElementwiseUnaryOp<"abs"> {
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1173,6 +1226,8 @@ def Tosa_BitwiseNotOp : Tosa_ElementwiseUnaryOp<"bitwise_not"> {
Profile<[Tosa_PRO_INT]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1197,6 +1252,8 @@ def Tosa_CeilOp : Tosa_ElementwiseUnaryOp<"ceil"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1221,6 +1278,8 @@ def Tosa_ClzOp : Tosa_ElementwiseUnaryOp<"clz"> {
Profile<[Tosa_PRO_INT]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1245,6 +1304,8 @@ def Tosa_CosOp : Tosa_ElementwiseUnaryOp<"cos"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1271,6 +1332,8 @@ def Tosa_ExpOp : Tosa_ElementwiseUnaryOp<"exp"> {
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1295,6 +1358,8 @@ def Tosa_FloorOp : Tosa_ElementwiseUnaryOp<"floor"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1321,6 +1386,8 @@ def Tosa_LogOp : Tosa_ElementwiseUnaryOp<"log"> {
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1345,6 +1412,8 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1424,6 +1493,8 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseUnaryOp<"reciprocal"> {
}];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1449,6 +1520,8 @@ def Tosa_RsqrtOp : Tosa_ElementwiseUnaryOp<"rsqrt"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1473,6 +1546,8 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16]>,
];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1519,6 +1594,8 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
::mlir::TypedValue<::mlir::TensorType> getOnTrue() { return getInput2(); }
::mlir::TypedValue<::mlir::TensorType> getOnFalse() { return getInput3(); }
}];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1559,6 +1636,8 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
}];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1586,6 +1665,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1614,6 +1695,8 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
];
let hasFolder = 1;
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1657,6 +1740,8 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
return leftOperand & rightOperand;
}
}];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1696,6 +1781,8 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
return leftOperand | rightOperand;
}
}];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1711,7 +1798,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1736,6 +1823,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
return (leftOperand.sge(rightOperand)) ? leftOperand : rightOperand;
}
}];
+
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -1751,7 +1840,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationModeAttr, "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1776,6 +1865,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
return (leftOperand.sle(rightOperand)) ? leftOperand : rightOperand;
}
}];
+
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -1815,6 +1906,8 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
return leftOperand * rightOperand;
}
}];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1854,6 +1947,8 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
return leftOperand + rightOperand;
}
}];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1894,6 +1989,8 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
/// Method used by InferTypeOpInterface.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
+
+ let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -1943,6 +2040,9 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
let hasFolder = 1;
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -2047,6 +2147,9 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -2078,6 +2181,9 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
let hasFolder = 1;
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -2111,6 +2217,9 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -2145,6 +2254,9 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
];
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -2180,6 +2292,9 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
];
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
@@ -2224,7 +2339,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
Rank4TosaShape:$scale,
Rank2TosaShape:$offset,
Rank2TosaShape:$border,
- Tosa_ResizeTypeAttr:$mode
+ Tosa_ResizeModeAttr:$mode
);
let results = (outs
@@ -2238,6 +2353,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -2247,7 +2363,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
//===----------------------------------------------------------------------===//
// Operator: cast
//===----------------------------------------------------------------------===//
-def Tosa_CastOp: Tosa_Op<"cast", [Pure,
+def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>]> {
@@ -2374,7 +2490,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp,
BoolAttr:$scale32,
- Tosa_RoundingTypeAttr:$rounding_mode,
+ Tosa_RoundingModeAttr:$rounding_mode,
BoolAttr:$per_channel,
BoolAttr: $input_unsigned,
BoolAttr: $output_unsigned
@@ -2397,8 +2513,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
}];
let hasVerifier = 1;
-
- let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640d..553d69cc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -234,29 +234,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
- "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
- "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
- "Supported rounding modes">;
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index c8f2907..d819cc1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
Tosa_IntLike:$value,
Tosa_IntLike:$multiplier,
Tosa_Int8Like:$shift,
- Tosa_RoundingTypeAttr:$rounding_mode
+ Tosa_RoundingModeAttr:$rounding_mode
);
let results = (outs
@@ -55,7 +55,7 @@ def Tosa_ApplyScaleOp :
std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
}];
- let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index d4e2661..7484473 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,7 +2,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRTosaPassIncGen)
-add_dependencies(mlir-headers MLIRTosaPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen)
add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/DebugExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/DebugExtension/CMakeLists.txt
index 2b49fae..a3ebfbf 100644
--- a/mlir/include/mlir/Dialect/Transform/DebugExtension/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/DebugExtension/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS DebugExtensionOps.td)
mlir_tablegen(DebugExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(DebugExtensionOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRTransformDialectDebugExtensionOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectDebugExtensionOpsIncGen)
add_mlir_doc(DebugExtensionOps DebugExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
index 9acab92..34e0801 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
@@ -5,25 +5,21 @@
set(LLVM_TARGET_DEFINITIONS TransformDialect.td)
mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform)
mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
-add_public_tablegen_target(MLIRTransformDialectIncGen)
-add_dependencies(mlir-headers MLIRTransformDialectIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectIncGen)
set(LLVM_TARGET_DEFINITIONS TransformTypes.td)
mlir_tablegen(TransformTypes.h.inc -gen-typedef-decls)
mlir_tablegen(TransformTypes.cpp.inc -gen-typedef-defs)
-add_public_tablegen_target(MLIRTransformTypesIncGen)
-add_dependencies(mlir-headers MLIRTransformTypesIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformTypesIncGen)
add_mlir_doc(TransformTypes TransformTypes Dialects/ -gen-typedef-doc)
set(LLVM_TARGET_DEFINITIONS TransformAttrs.td)
mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRTransformDialectEnumIncGen)
-add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectEnumIncGen)
mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRTransformDialectAttributesIncGen)
-add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectAttributesIncGen)
add_mlir_dialect(TransformOps transform)
add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
index dfcd906..3bcaab0 100644
--- a/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS IRDLExtensionOps.td)
mlir_tablegen(IRDLExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(IRDLExtensionOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)
add_mlir_doc(IRDLExtensionOps IRDLExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/Interfaces/CMakeLists.txt
index 14ce5b8..195f4d9 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/CMakeLists.txt
@@ -6,8 +6,7 @@ add_mlir_doc(TransformInterfaces TransformOpInterfaces Dialects/ -gen-op-interfa
set(LLVM_TARGET_DEFINITIONS TransformInterfaces.td)
mlir_tablegen(TransformTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TransformTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRTransformDialectTypeInterfacesIncGen)
-add_dependencies(mlir-headers MLIRTransformDialectTypeInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectTypeInterfacesIncGen)
add_mlir_doc(TransformInterfaces TransformTypeInterfaces Dialects/ -gen-type-interface-docs)
add_mlir_interface(MatchInterfaces)
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
index 8f5e510..afb2c77 100644
--- a/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS LoopExtensionOps.td)
mlir_tablegen(LoopExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(LoopExtensionOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen)
add_mlir_doc(LoopExtensionOps LoopExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt
index 6af6b83..fa3c003 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS PDLExtensionOps.td)
mlir_tablegen(PDLExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(PDLExtensionOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen)
add_mlir_doc(PDLExtensionOps PDLExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt
index 3a399e6..b804802 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transform)
-add_public_tablegen_target(MLIRTransformDialectTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectTransformsIncGen)
add_mlir_doc(Passes TransformPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
index 9afca81..81893af 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td)
mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)
add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/UB/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/UB/IR/CMakeLists.txt
index 0449cb2..bb63a78 100644
--- a/mlir/include/mlir/Dialect/UB/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/UB/IR/CMakeLists.txt
@@ -5,11 +5,11 @@ mlir_tablegen(UBOpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(UBOpsDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(UBOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=ub)
mlir_tablegen(UBOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ub)
-add_public_tablegen_target(MLIRUBOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRUBOpsIncGen)
add_mlir_doc(UBOps UBOps Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS UBOpsInterfaces.td)
mlir_tablegen(UBOpsInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(UBOpsInterfaces.cpp.inc -gen-attr-interface-defs)
-add_public_tablegen_target(MLIRUBOpsInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRUBOpsInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/Utils/CMakeLists.txt b/mlir/include/mlir/Dialect/Utils/CMakeLists.txt
index edfb1ca..0945d7d 100644
--- a/mlir/include/mlir/Dialect/Utils/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,4 @@
set(LLVM_TARGET_DEFINITIONS StructuredOpsUtils.td)
mlir_tablegen(DialectUtilsEnums.h.inc -gen-enum-decls)
mlir_tablegen(DialectUtilsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRDialectUtilsIncGen)
-add_dependencies(mlir-headers MLIRDialectUtilsIncGen)
+add_mlir_dialect_tablegen_target(MLIRDialectUtilsIncGen)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt
index 5bbc4c6..0dc45db 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt
@@ -5,8 +5,7 @@ add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
mlir_tablegen(VectorOps.h.inc -gen-op-decls)
mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRVectorOpsIncGen)
-add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen)
+add_mlir_generic_tablegen_target(MLIRVectorOpsIncGen)
# Add Vector attributes
set(LLVM_TARGET_DEFINITIONS VectorAttributes.td)
@@ -14,5 +13,4 @@ mlir_tablegen(VectorEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRVectorAttributesIncGen)
-add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen)
+add_mlir_generic_tablegen_target(MLIRVectorAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 364c172..63410b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -32,6 +32,7 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Alignment.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dc55704..77e26cc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -754,7 +754,7 @@ def Vector_FMAOp :
}
def Vector_ToElementsOp : Vector_Op<"to_elements", [
- Pure,
+ InferTypeOpAdaptor, Pure,
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
@@ -1712,35 +1712,33 @@ def Vector_LoadOp : Vector_Op<"load", [
An optional `alignment` attribute allows to specify the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
- memory at an address aligned to this boundary. Violations may lead to
- architecture-specific faults or performance penalties.
- A value of 0 indicates no specific alignment requirement.
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment);
let builders = [
OpBuilder<(ins "VectorType":$resultType,
"Value":$base,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
- CArg<"uint64_t", "0">:$alignment), [{
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
return build($_builder, $_state, resultType, base, indices, nontemporal,
- alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
}]>,
OpBuilder<(ins "TypeRange":$resultTypes,
"Value":$base,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
- CArg<"uint64_t", "0">:$alignment), [{
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
return build($_builder, $_state, resultTypes, base, indices, nontemporal,
- alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
}]>
];
@@ -1828,9 +1826,8 @@ def Vector_StoreOp : Vector_Op<"store", [
An optional `alignment` attribute allows to specify the byte alignment of the
store operation. It must be a positive power of 2. The operation must access
- memory at an address aligned to this boundary. Violations may lead to
- architecture-specific faults or performance penalties.
- A value of 0 indicates no specific alignment requirement.
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
}];
let arguments = (ins
@@ -1839,17 +1836,16 @@ def Vector_StoreOp : Vector_Op<"store", [
[MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment);
let builders = [
OpBuilder<(ins "Value":$valueToStore,
"Value":$base,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
- CArg<"uint64_t", "0">:$alignment), [{
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
return build($_builder, $_state, valueToStore, base, indices, nontemporal,
- alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
}]>
];
@@ -1876,7 +1872,8 @@ def Vector_MaskedLoadOp :
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$pass_thru)>,
+ AnyVectorOfNonZeroRank:$pass_thru,
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1912,6 +1909,11 @@ def Vector_MaskedLoadOp :
%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
+
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ load operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -1932,6 +1934,29 @@ def Vector_MaskedLoadOp :
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultType, base, indices, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>,
+ OpBuilder<(ins "TypeRange":$resultTypes,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultTypes, base, indices, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_MaskedStoreOp :
@@ -1939,7 +1964,8 @@ def Vector_MaskedStoreOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore)> {
+ AnyVectorOfNonZeroRank:$valueToStore,
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
let summary = "stores elements from a vector into memory as defined by a mask vector";
@@ -1974,6 +2000,11 @@ def Vector_MaskedStoreOp :
vector.maskedstore %base[%i, %j], %mask, %value
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
```
+
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ store operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -1992,6 +2023,18 @@ def Vector_MaskedStoreOp :
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, base, indices, mask, valueToStore,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_GatherOp :
@@ -2000,30 +2043,46 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
- Variadic<Index>:$indices,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$pass_thru)>,
+ AnyVectorOfNonZeroRank:$pass_thru,
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = [{
- gathers elements from memory or ranked tensor into a vector as defined by an
- index vector and a mask vector
+ Gathers elements from memory or ranked tensor into a vector as defined by an
+ index vector and a mask vector.
}];
let description = [{
The gather operation returns an n-D vector whose elements are either loaded
- from memory or ranked tensor, or taken from a pass-through vector, depending
+ from a k-D memref or tensor, or taken from an n-D pass-through vector, depending
on the values of an n-D mask vector.
- If a mask bit is set, the corresponding result element is defined by the base
- with indices and the n-D index vector (each index is a 1-D offset on the base).
- Otherwise, the corresponding element is taken from the n-D pass-through vector.
- Informally the semantics are:
+
+ If a mask bit is set, the corresponding result element is taken from `base`
+ at an index defined by k indices and n-D `index_vec`. Otherwise, the element
+ is taken from the pass-through vector. As an example, suppose that `base` is
+ 3-D and the result is 2-D:
+
+ ```mlir
+ func.func @gather_3D_to_2D(
+ %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+ %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
+ %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+ %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+ [%indices], %mask, %fall_thru : [...]
+ return %result : vector<2x3xf32>
+ }
```
- result[0] := if mask[0] then base[index[0]] else pass_thru[0]
- result[1] := if mask[1] then base[index[1]] else pass_thru[1]
- etc.
+
+ The indexing semantics are then,
+
+ ```
+ result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
+ else pass_thru[i,j]
```
+ The index into `base` only varies in the innermost ((k-1)-th) dimension.
If a mask bit is set and the corresponding index is out-of-bounds for the
given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2034,12 +2093,19 @@ def Vector_GatherOp :
during progressively lowering to bring other memory operations closer to
hardware ISA support for a gather.
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ gather operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
+
Examples:
```mlir
+ // 1-D memref gathered to 2-D vector.
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
: memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ // 2-D memref gathered to 1-D vector.
%1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
@@ -2047,28 +2113,43 @@ def Vector_GatherOp :
let extraClassDeclaration = [{
ShapedType getBaseType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getPassThruVectorType() { return getPassThru().getType(); }
VectorType getVectorType() { return getResult().getType(); }
}];
let assemblyFormat =
- "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
- "type($index_vec) `,` type($mask) `,` type($pass_thru) "
+ "type($indices) `,` type($mask) `,` type($pass_thru) "
"`into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$index_vec,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
- Variadic<Index>:$indices,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore)> {
+ AnyVectorOfNonZeroRank:$valueToStore,
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
let summary = [{
scatters elements from a vector into memory as defined by an index vector
@@ -2102,6 +2183,11 @@ def Vector_ScatterOp :
correspond to those of the `llvm.masked.scatter`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ scatter operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
+
Examples:
```mlir
@@ -2115,17 +2201,30 @@ def Vector_ScatterOp :
let extraClassDeclaration = [{
MemRefType getMemRefType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];
let assemblyFormat =
- "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
- "type($index_vec) `,` type($mask) `,` type($valueToStore)";
+ "type($indices) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "Value":$base,
+ "ValueRange":$indices,
+ "Value":$index_vec,
+ "Value":$mask,
+ "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
+ return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_ExpandLoadOp :
@@ -2133,7 +2232,8 @@ def Vector_ExpandLoadOp :
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$pass_thru)>,
+ AnyVectorOfNonZeroRank:$pass_thru,
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2165,6 +2265,11 @@ def Vector_ExpandLoadOp :
correspond to those of the `llvm.masked.expandload`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ load operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
+
Note, at the moment this Op is only available for fixed-width vectors.
Examples:
@@ -2195,6 +2300,19 @@ def Vector_ExpandLoadOp :
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType,
+ "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$passthrough,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, resultType, base, indices, mask, passthrough,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_CompressStoreOp :
@@ -2202,7 +2320,8 @@ def Vector_CompressStoreOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
- AnyVectorOfNonZeroRank:$valueToStore)> {
+ AnyVectorOfNonZeroRank:$valueToStore,
+ OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
let summary = "writes elements selectively from a vector as defined by a mask";
@@ -2233,6 +2352,11 @@ def Vector_CompressStoreOp :
correspond to those of the `llvm.masked.compressstore`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
+ An optional `alignment` attribute allows to specify the byte alignment of the
+ store operation. It must be a positive power of 2. The operation must access
+ memory at an address aligned to this boundary. Violating this requirement
+ triggers immediate undefined behavior.
+
Note, at the moment this Op is only available for fixed-width vectors.
Examples:
@@ -2261,6 +2385,17 @@ def Vector_CompressStoreOp :
"type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$base,
+ "ValueRange":$indices,
+ "Value":$mask,
+ "Value":$valueToStore,
+ CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+ return build($_builder, $_state, base, indices, valueToStore, mask,
+ alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+ nullptr);
+ }]>
+ ];
}
def Vector_ShapeCastOp :
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt
index ec8b0f4..8a90d27 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/CMakeLists.txt
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS VectorTransformOps.td)
mlir_tablegen(VectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(VectorTransformOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRVectorTransformOpsIncGen)
+add_mlir_dialect_tablegen_target(MLIRVectorTransformOpsIncGen)
add_mlir_doc(VectorTransformOps VectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 299f198..07a4117 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -254,6 +254,17 @@ def ApplyLowerGatherPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.unroll_from_elements",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector from_elements operations should be unrolled
+ along the outermost dimension.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_scan",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
index 2c288fe..c3c7ec9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,6 @@ mlir_tablegen(VectorTransformsEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector)
-add_public_tablegen_target(MLIRVectorTransformsIncGen)
+add_mlir_dialect_tablegen_target(MLIRVectorTransformsIncGen)
add_mlir_doc(Passes VectorPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e03f0da..47f9611 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -305,6 +305,14 @@ void populateVectorToFromElementsToShuffleTreePatterns(
/// Populate the pattern set with the following patterns:
///
+/// [UnrollFromElements]
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension.
+void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
/// [ContractionOpToMatmulOpLowering]
/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
///
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 8bd54cf..ace2699 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -238,6 +239,22 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
/// static sizes in `shape`.
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
ArrayRef<int64_t> inputVectorSizes);
+
+/// Generic utility for unrolling n-D vector operations to (n-1)-D operations.
+/// This handles the common pattern of:
+/// 1. Check if already 1-D. If so, return failure.
+/// 2. Check for scalable dimensions. If so, return failure.
+/// 3. Create poison initialized result.
+/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to
+/// create sub vectors.
+/// 5. Insert the sub vectors back into the final vector.
+/// 6. Replace the original op with the new result.
+using UnrollVectorOpFn =
+ function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
+
+LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ UnrollVectorOpFn unrollFn);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
index 0d305ba..98f8966 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
@@ -5,7 +5,7 @@ mlir_tablegen(WasmSSATypeConstraints.cpp.inc -gen-type-constraint-defs)
set (LLVM_TARGET_DEFINITIONS WasmSSAInterfaces.td)
mlir_tablegen(WasmSSAInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(WasmSSAInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRWasmSSAInterfacesIncGen)
+add_mlir_dialect_tablegen_target(MLIRWasmSSAInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS WasmSSAOps.td)
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index 676621b..b80ee2c 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -329,11 +329,11 @@ def WasmSSA_IfOp : WasmSSA_Op<"if", [Terminator,
```mlir
// Runs the if clause is %a is non-zero
- "wasmssa.if"(%a)[^bb1] ({
+ wasmssa.if %a {
// Execute if %a is non-zero
- },{
+ } else {
// else clause
- }) : (i32) -> ()
+ }
```
}];
let arguments = (ins I32:$condition, Variadic<WasmSSA_ValType>: $inputs);
@@ -359,6 +359,7 @@ def WasmSSA_IfOp : WasmSSA_Op<"if", [Terminator,
return createBlock(getElse());
}
}];
+ let assemblyFormat = "$condition (`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $if custom<ElseRegion>($else) `>` $target";
}
def WasmSSA_LocalOp : WasmSSA_Op<"local", [
@@ -445,7 +446,7 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
```mlir
// Define the `mem_0` memory with defined bounds of 0 -> 65536
- "wasmssa.memory"() <{limits = !wasmssa<limit[0:65536]>, sym_name = "mem_0"}> : () -> ()
+ wasmssa.memory @mem_0 !wasmssa<limit[0:65536]>
```
}];
let arguments = (ins SymbolNameAttr: $sym_name,
@@ -456,6 +457,8 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
"::llvm::StringRef":$symbol,
"wasmssa::LimitType":$limit)>
];
+
+ let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict";
}
def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> {
@@ -494,6 +497,7 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
let builders = [OpBuilder<(ins
"::llvm::StringRef":$symbol,
"wasmssa::TableType":$type)>];
+ let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict";
}
def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
index 3f8cac4..efca3cf 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -4,11 +4,14 @@ add_mlir_doc(XeGPU XeGPU Dialects/ -gen-dialect-doc -dialect=xegpu)
set(LLVM_TARGET_DEFINITIONS XeGPU.td)
mlir_tablegen(XeGPUAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=xegpu)
mlir_tablegen(XeGPUAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xegpu)
-add_public_tablegen_target(MLIRXeGPUAttrsIncGen)
-add_dependencies(mlir-headers MLIRXeGPUAttrsIncGen)
+add_mlir_dialect_tablegen_target(MLIRXeGPUAttrsIncGen)
set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
-add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
+add_mlir_dialect_tablegen_target(MLIRXeGPUEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
+mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
+add_mlir_dialect_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 8e2784f..1481859 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,10 +11,12 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
@@ -22,17 +24,20 @@
namespace mlir {
namespace xegpu {
class TensorDescType;
+class DistributeLayoutAttr;
+class LayoutAttr;
+class SliceAttr;
} // namespace xegpu
} // namespace mlir
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
-
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 64eb21c..cfe3e80 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,67 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
+def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
+ let cppNamespace = "::mlir::xegpu";
+ let description = [{
+ Common trait for all XeGPU layouts.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Check the availability of workgroup level layouts",
+ "bool",
+ "isForWorkgroup">,
+ InterfaceMethod<"Check the availability of subgroup level layouts",
+ "bool",
+ "isForSubgroup">,
+ InterfaceMethod<"Get the rank of attribute",
+ "int64_t",
+ "getRank">,
+ InterfaceMethod<"Get the num of effective subgroups",
+ "int64_t",
+ "getNumSubgroups", (ins), [{
+ std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getSgLayoutAsInt();
+ if (sgLayout.has_value())
+ return computeProduct(*sgLayout);
+ return 0;
+ }], [{}]>,
+ InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
+ "SmallVector<int64_t>",
+ "getSgLayoutAsInt">,
+ InterfaceMethod<"Get the SgData field of the attribute as integer array",
+ "SmallVector<int64_t>",
+ "getSgDataAsInt">,
+ InterfaceMethod<"Get the InstData field of the attribute as integer array",
+ "SmallVector<int64_t>",
+ "getInstDataAsInt">,
+ InterfaceMethod<"Get the LaneLayout field of the attribute as integer array",
+ "SmallVector<int64_t>",
+ "getLaneLayoutAsInt">,
+ InterfaceMethod<"Get the LaneData field of the attribute as integer array",
+ "SmallVector<int64_t>",
+ "getLaneDataAsInt">,
+ InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
+ "xegpu::DistributeLayoutAttr",
+ "dropSgLayoutAndData">,
+ InterfaceMethod<"Derive a new layout by dropping InstData",
+ "xegpu::DistributeLayoutAttr",
+ "dropInstData">,
+ InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
+ indices based on the effective subgroup layout.}],
+ "FailureOr<SmallVector<Value>>",
+ "delinearizeSubgroupId",
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
+ InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
+ assigned to a subgroup identified by linearId. The shape parameter
+ represents the workgroup-level problem size. Each subgroup may access
+ multiple blocks according to round-robin distribution rules.}],
+ "FailureOr<SmallVector<SmallVector<Value>>>",
+ "getOffsets",
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
+ ];
+}
+
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
@@ -297,12 +357,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
];
let extraClassDeclaration = [{
- bool isWgLayout() {
+ bool isForWorkgroup() {
return getSgLayout() != nullptr;
}
- bool isSgLayout() {
- return !isWgLayout();
+ bool isForSubgroup() {
+ return !isForWorkgroup();
}
int64_t getRank() {
@@ -330,12 +390,216 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
getLaneLayout(), getLaneData(), getOrder());
}
+
+ SmallVector<int64_t> getSgLayoutAsInt() const {
+ if (DenseI32ArrayAttr layout = getSgLayout())
+ return llvm::to_vector_of<int64_t>(layout.asArrayRef());
+ return {};
+ }
+
+ SmallVector<int64_t> getSgDataAsInt() const {
+ if (DenseI32ArrayAttr data = getSgData())
+ return llvm::to_vector_of<int64_t>(data.asArrayRef());
+ return {};
+ }
+
+ SmallVector<int64_t> getInstDataAsInt() const {
+ if (DenseI32ArrayAttr inst = getInstData())
+ return llvm::to_vector_of<int64_t>(inst.asArrayRef());
+ return {};
+ }
+
+ SmallVector<int64_t> getLaneLayoutAsInt() const {
+ if (DenseI32ArrayAttr layout = getLaneLayout())
+ return llvm::to_vector_of<int64_t>(layout.asArrayRef());
+ return {};
+ }
+
+ SmallVector<int64_t> getLaneDataAsInt() const {
+ if (DenseI32ArrayAttr data = getLaneData())
+ return llvm::to_vector_of<int64_t>(data.asArrayRef());
+ return {};
+ }
+
+ /// Delinearizes a linear subgroup ID into its multidimensional indices
+ /// based on the effective subgroup layout.
+ FailureOr<SmallVector<Value>>
+ delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+
+ /// Generates instructions to compute multidimensional offsets for blocks
+ /// assigned to a subgroup identified by linearId. The shape parameter
+ /// represents the workgroup-level problem size. Each subgroup may access
+ /// multiple blocks according to round-robin distribution rules.
+ FailureOr<SmallVector<SmallVector<Value>>>
+ getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+
}];
let assemblyFormat = "`<` struct(params) `>`";
let genVerifyDecl = 1;
}
+
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
+ let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
+
+ let description = [{
+ Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
+ However, whereas LayoutAttr requires the data to have the same rank as the attribute,
+ SliceAttr permits the data to have a lower rank. In this case, compute units in the
+ specified dimensions (given by `$dims`) share the data, provided that the remaining
+ ranks match the data rank. SliceAttr is commonly used by operations such as
+ vector.multi_reduction and vector.broadcast.
+
+ Example:
+ ```
+ #l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+ #r = #xegpu.slice<#l, dim = [0]>
+
+ %exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
+ %red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
+ %bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
+ ```
+ In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to
+ a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a
+ single reduction result of type vector<32xf32>.
+
+ }];
+
+ let parameters = (ins
+ "xegpu::DistributeLayoutAttr": $parent,
+ "DenseI64ArrayAttr": $dims
+ );
+
+ let extraClassDeclaration = [{
+
+ int64_t getRank() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.getRank() - attr.getDims().size();
+ }
+
+ DenseI32ArrayAttr getOrder() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.getOrder();
+ }
+
+ bool isForWorkgroup() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.isForWorkgroup();
+ }
+
+ bool isForSubgroup() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.isForSubgroup();
+ }
+
+ /// Returns the SgLayout of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ SmallVector<int64_t> getSgLayoutAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ auto layout = parent.getSgLayoutAsInt();
+ if (layout.size()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(ArrayRef<int64_t>(layout), dims);
+ }
+ return {};
+ }
+
+ /// Returns the SgData of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ SmallVector<int64_t> getSgDataAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ auto data = parent.getSgDataAsInt();
+ if (data.size()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(ArrayRef<int64_t>(data), dims);
+ }
+ return {};
+ }
+
+ /// Returns the InstData of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ SmallVector<int64_t> getInstDataAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ auto inst = parent.getInstDataAsInt();
+ if (inst.size()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(inst), dims);
+ }
+ return {};
+ }
+
+ /// Returns the LaneLayout of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ SmallVector<int64_t> getLaneLayoutAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ auto layout = parent.getLaneLayoutAsInt();
+ if (layout.size()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(layout), dims);
+ }
+ return {};
+ }
+
+ /// Returns the LaneData of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ SmallVector<int64_t> getLaneDataAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ auto data = parent.getLaneDataAsInt();
+ if (data.size()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(data), dims);
+ }
+ return {};
+ }
+
+ SliceAttr dropSgLayoutAndData() {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ parent = parent.dropSgLayoutAndData();
+ return SliceAttr::get(getContext(), parent, attr.getDims());
+ }
+
+ SliceAttr dropInstData() {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ parent = parent.dropInstData();
+ return SliceAttr::get(getContext(), parent, attr.getDims());
+ }
+
+ /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
+ /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
+ /// it will coalese two slice operations and return a simplified SliceAttr
+ /// #xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0, 1]>
+ SliceAttr flatten() const;
+
+ /// Delinearizes a linear subgroup ID into its multidimensional indices
+ /// based on the effective subgroup layout.
+ FailureOr<SmallVector<Value>>
+ delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+
+ /// Generates instructions to compute multidimensional offsets for blocks
+ /// assigned to a subgroup identified by linearId. The shape parameter
+ /// represents the workgroup-level problem size. Each subgroup may access
+ /// multiple blocks according to round-robin distribution rules.
+ FailureOr<SmallVector<SmallVector<Value>>>
+ getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+
+ }];
+
+ let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
+ let genVerifyDecl = 1;
+}
+
def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
let summary = [{Specifies a half-open range}];
let description = [{
@@ -365,4 +629,34 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
let genVerifyDecl = 1;
}
+def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
+ let summary = [{Specifies memory layouts with named attributes.}];
+
+ let description = [{
+ This attribute stores a collection of named attributes that describe
+ memory layout properties such as stride, block, etc.
+ }];
+
+ let parameters = (ins "DictionaryAttr": $attrs);
+ let hasCustomAssemblyFormat = 1;
+
+ let extraClassDeclaration = [{
+ /// Get a specific attribute by name
+ Attribute getAttr(StringRef name) const {
+ return getAttrs().get(name);
+ }
+
+ /// Check if a specific attribute exists
+ bool hasAttr(StringRef name) const {
+ return getAttrs().contains(name);
+ }
+
+ ArrayAttr getStrides() {
+ return getAttrs().getAs<ArrayAttr>("stride");
+ }
+
+ }];
+
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 549018b..c173b93 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -40,7 +40,19 @@ def XeGPU_Dialect : Dialect {
let extraClassDeclaration = [{
/// Checks if the given shape can be evenly distributed based on the layout
/// and data factors provided by the LayoutAttr.
- static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+ static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
+
+ /// drops/slices the shape in the specified dims, and return the rest. e.g.,
+ /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
+ template<typename T, typename U>
+ static llvm::SmallVector<T> slice(llvm::ArrayRef<T> shape, llvm::ArrayRef<U> dims) {
+ llvm::SmallVector<T> result;
+ for (auto [i, v]: llvm::enumerate(shape)) {
+ if (!llvm::is_contained(dims, i))
+ result.push_back(v);
+ }
+ return result;
+ }
}];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 75b16a87..73f9061 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-
+
DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}
if (!filteredAttrs.empty()) {
- p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+ p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
}
}
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}
-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
- AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
let summary = "Create nd-tensor descriptor operation";
let description = [{
@@ -71,28 +70,32 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
future). Elements in the subview continuous in each dimension. It encodes the
following important information for supporting Intel hardware features:
- * source: an object representing (starting address/pointer of) a memory region.
+ Arguments:
+ - `source`: an object representing (starting address/pointer of) a memory region.
It can be either a memref object, or simply a pointer represented by uint64_t type.
For the case of dynamic memrefs or pointer, the shape and layout information of the
memory region should be explicitly passed via `shape` and `strides` parameters.
- * offsets: index values represents offsets from the "source" at the each dimension
+ - `offsets`: index values represents offsets from the "source" at the each dimension
at which the subview of the target memory will be created. It is encoded via
"offsets" and "const_offsets", such that it can accept various forms, such as,
operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
- * shape: the shape information of the memory region pointed by the "source". It is
+ - `shape`: the shape information of the memory region pointed by the "source". It is
typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
But if "source" is simply a pointer represented as uint64_t type, or a memref
type without shape information e.g., memref<?x?xf16>, the shape information has
to be explicitly passed via the "shape" and "const_shape" arguments.
- * strides: the strides of the memory region pointed by the "source". Similar to shape,
+ - `strides`: the strides of the memory region pointed by the "source". Similar to shape,
it is typically encoded via the MemRefType of the source too. But if "source" is
simply a pointer represented as uint64_t type, or a memref type without shape
information e.g., memref<?x?xf16>, the strides information has to be explicitly
passed via the "strides" and "const_strides" argument.
+ Results:
+ - `res`: nd tensor descriptor
+
Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = memref.alloc() : memref<1024x1024xf32>
@@ -143,11 +146,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
let builders = [
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
- OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
- "llvm::ArrayRef<OpFoldResult>": $shape,
- "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
- OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+ OpBuilder<(ins "Type": $tdesc, "Value ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return getType().getShape();
}
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- OperandRange getSizes() {
- return getShape();
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+ auto dynamics = getOffsets();
+ if (statics.size() == 0 && dynamics.size() == 0)
+ return {};
+ return getMixedValues(statics, dynamics, getContext());
}
- ArrayRef<int64_t> getStaticOffsets(){
- auto attr = getConstOffsetsAttr();
+ SmallVector<OpFoldResult> getMixedSizes() {
+ SmallVector<int64_t> statics;
- if (attr)
- return attr;
+ /// Get the static sizes/shape, the value passed to const_shape
+ /// will overide the value in memref shape.
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+ statics = llvm::to_vector(memrefTy.getShape());
+ if (auto attr = getConstShapeAttr())
+ statics = llvm::to_vector(attr.asArrayRef());
- int64_t rank = getMixedSizes().size();
-
- setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
-
- attr = getConstOffsetsAttr();
- return attr;
+ return getMixedValues(statics, getShape(), getContext());
}
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- /// If source is IntegerType or `const_shape` is filled,
- /// it will return `const_shape`, such that mixes of `shape`
- /// and `const_shape` will be used to represent the shape of
- /// source operand. They overide static shape from source memref type.
- ArrayRef<int64_t> getStaticSizes() {
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
- static llvm::SmallVector<int64_t, 4> emptyShape;
-
- auto attr = getConstShapeAttr();
- if (attr)
- return attr;
-
- if (llvm::isa<IntegerType>(getSourceType()))
- return emptyShape;
-
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
- assert(memrefType && "Incorrect use of getStaticSizes");
- return memrefType.getShape();
- }
+ SmallVector<OpFoldResult> getMixedStrides() {
+ SmallVector<int64_t> statics;
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- /// If source is IntegerType or `const_strides` is filled, it
- /// will return `const_strides`, such that mixes of `strides`
- /// and `const_strides` will be used to represent the strides of
- /// source operand. They overide static strides from source memref type.
- ArrayRef<int64_t> getStaticStrides() {
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
- static llvm::SmallVector<int64_t, 4> emptyStrides;
-
- auto attr = getConstStridesAttr();
- if (attr)
- return attr;
-
- if (llvm::isa<IntegerType>(getSourceType()))
- return emptyStrides;
-
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
- assert(memrefType && "Incorrect use of getStaticStrides");
- auto [strides, _] = memrefType.getStridesAndOffset();
- // reuse the storage of ConstStridesAttr since strides from
- // memref is not persistant
- setConstStrides(strides);
- attr = getConstStridesAttr();
- return attr;
- }
+ /// Get the static strides, the value passed to const_strides
+ /// will overide the value in memref.
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+ statics = memrefTy.getStridesAndOffset().first;
+ if (auto attr = getConstStridesAttr())
+ statics = llvm::to_vector(attr.asArrayRef());
- /// Return the expected rank of each of the`static_offsets`,
- /// `static_shape` and `static_strides` attributes.
- std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned rank;
- if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
- rank = ty.getRank();
- } else {
- rank = (unsigned)getMixedOffsets().size();
- }
- return {rank, rank, rank};
+ return getMixedValues(statics, getStrides(), getContext());
}
/// Return the number of leading operands before the `offsets`,
@@ -281,6 +236,14 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return static_cast<unsigned>(MemorySpace::Global);
}
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
+ }
+
+ ArrayRef<int64_t> getDataShape() {
+ return getTensorDescShape();
+ }
+
}];
}
@@ -311,18 +274,40 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
}
+
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+ auto dynamics = getOffsets();
+ if (statics.size() == 0 && dynamics.size() == 0)
+ return {};
+ return getMixedValues(statics, dynamics, getContext());
+ }
+
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+ }
+
+ ArrayRef<int64_t> getDataShape() {
+ return getTensorDescType().getShape();
+ }
+
}];
let assemblyFormat = [{
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc))
}];
let builders = [
- OpBuilder<(ins "Value": $TensorDesc,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ OpBuilder<(ins "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Value": $TensorDesc,
+ "ArrayRef<OpFoldResult>": $offsets,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -370,7 +355,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -387,19 +372,43 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
}
+
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+ auto dynamics = getOffsets();
+ if (statics.size() == 0 && dynamics.size() == 0)
+ return {};
+ return getMixedValues(statics, dynamics, getContext());
+ }
+
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+ }
+
+ ArrayRef<int64_t> getDataShape() {
+ return getTensorDescType().getShape();
+ }
+
+
}];
let assemblyFormat = [{
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
}];
let builders = [
- OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+ "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+ "ArrayRef<OpFoldResult>": $offsets,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -442,7 +451,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
let arguments = (ins XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -455,20 +464,42 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
}
+
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+ auto dynamics = getOffsets();
+ if (statics.size() == 0 && dynamics.size() == 0)
+ return {};
+ return getMixedValues(statics, dynamics, getContext());
+ }
+
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
+ }
+
+ ArrayRef<int64_t> getDataShape() {
+ return getTensorDescType().getShape();
+ }
+
}];
let assemblyFormat = [{
- $value `,`
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $value `,`
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
}];
let builders = [
- OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
- "xegpu::CachePolicyAttr": $l3_hint)>
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+ "ArrayRef<OpFoldResult>": $offsets,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -533,12 +564,17 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
(scattered) subviews, allowing each work-item in a subgroup specifying their own offset.
It accepts the following parameters:
- * source: a 1D memref or pointer (uint64_t) represents the flattened memory object.
- * offsets: a vector containing offsets of each access point. Its size
+ Arguments:
+ - `source`: a 1D memref or pointer (i64, i32, ui64, ui32) represents the flattened
+ memory object.
+ - `offsets`: a vector containing offsets of each access point. Its size
is fixed to the hardware supportted subgroup size, e.g., 16 on PVC,
implying each element in the vector corresponds to a work-item (SIMT lane)
in the subgroup.
+ Results:
+ - `res`: scattered tensor descriptor
+
The first dimension of the result TensorDesc corresponds to work-items, so it should
match the dimension of offsets. It may also has a second dimension corresponding to
the chunk_size if the chunk size is larger than 1.
@@ -569,8 +605,8 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
```
}];
- let arguments = (ins XeGPU_BaseAddrType: $source,
- XeGPU_OffsetType: $offsets);
+ let arguments = (ins XeGPU_GatherScatterBaseAddrType:$source,
+ XeGPU_OffsetType:$offsets);
let results = (outs XeGPU_TensorDesc:$TensorDesc);
let builders = [
@@ -628,6 +664,18 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.
+ Arguments:
+ - `source`: represents the memory region to be loaded from, which can be either a
+ tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+ In case of tensor_desc, offsets come from the producer create_tdesc op.
+ tensor_desc cannot be used in SIMT mode.
+ - `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
+ offsets is a vector of `index` type and vector length is either the subgroup size
+ or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+ - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
+ - `offset_align_byte`: required if `source` is a pointer. If `source` is not a pointer,
+ it is not allowed. Represents the alignment in bytes of each offset in offsets.
+
Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
@@ -635,12 +683,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```
-
+
Example 2:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
- The source operand could be a raw pointer (uint64_t).
- Please refer to create_tdesc for the restriction of memref.
+ The source operand could be a raw pointer (ui64, ui32, i64, i32).
+ Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -650,13 +698,33 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
: memref<1024xf32>, vector<4xindex>
```
+ Example 3 (SIMT mode):
+ SIMT mode only accepts the offsets variant.
+ ```mlir
+ xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+ : memref<256xf32>, vector<1xindex>
+ ```
+
+ Example 4 (SIMT mode):
+ SIMT mode only accepts the offsets variant.
+ ```mlir
+ xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>,
+ offset_align_byte = 2}
+ : i64, vector<1xindex>
+ ```
+
}];
- let arguments = (ins XeGPU_GatherScatterSourceType: $source,
- Optional<XeGPU_OffsetType>: $offsets,
- OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+ let arguments = (ins XeGPU_GatherScatterSourceType:$source,
+ Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+ OptionalAttr<I64Attr>:$offset_align_byte);
let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
@@ -673,19 +741,20 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
xegpu::TensorDescType getTensorDescType() {
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
+
}];
let assemblyFormat = [{
- $source
+ $source
(`[` $offsets^ `]`)?
prop-dict
- attr-dict `:` type(operands)
+ attr-dict `:` type(operands)
}];
-
+
let builders = [
OpBuilder<(ins "Value": $source,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -703,8 +772,26 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
The mask operand masks out memory access so that it is safe to pass out-of-boundary
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
- In SIMT mode, the result vector represents the data to be loaded by each work-item.
- Each work-item recieves a `chunk_size` number of elements.
+ In SIMT mode, the result is a 1D vector that represents the data to be loaded by
+ each work-item. If size is not 1, size should be equal to the chunk size,
+
+ Arguments:
+ - `source`: represents the memory region to be loaded from, which can be either a
+ tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+ In case of tensor_desc, offsets come from the producer create_tdesc op.
+ tensor_desc cannot be used in SIMT mode.
+ - `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
+ offsets is a vector of `index` type and vector length is either the subgroup size
+ or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+ - `mask`: is a vector of `i1` type, which is used to mask out the memory access.
+ mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+ scalar mask is also valid for SIMT mode.
+ - `chunk_size`: (optional) represents contiguous number of elements to load from per work item.
+ - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
+
+ Results:
+ - `res`: represents loaded data
+
Example 1:
```mlir
@@ -723,21 +810,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```
-
- Example 3 (SIMT mode):
- ```mlir
- %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
- l2_hint = #xegpu.cache_hint<uncached>,
- l3_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
- vector<16xi1> -> vector<8xf32>
- ```
-
- Example 4:
+
+ Example 3:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
- The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
- for the restriction of memref.
+ The source operand could be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc
+ for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
@@ -748,16 +826,25 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
```
+ Example 4 (SIMT mode):
+ SIMT mode only accepts the offsets variant. chunk_size can be inferred from result
+ type. In this example, chunk_size is 8.
+ ```mlir
+ %2 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>,
+ l3_hint = #xegpu.cache_hint<uncached>}>
+ : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
+ ```
+
}];
- let arguments = (ins XeGPU_GatherScatterSourceType: $source,
- Optional<XeGPU_OffsetType>: $offsets,
- XeGPU_MaskType: $mask,
- OptionalAttr<I64Attr>: $chunk_size,
- OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
- let results = (outs XeGPU_ValueType: $value);
+ let arguments = (ins XeGPU_GatherScatterSourceType:$source,
+ Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+ AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -794,14 +881,20 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
- $mask prop-dict
+ $mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Type": $value, "Value": $source,
+ "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+ "IntegerAttr": $chunk_size,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -810,15 +903,31 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
- let description = [{ It (aka. store) stores data to scattered memory locations. The value is
+ let description =
+ [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
a 2D vector instead. For the later case, dim-1 of the value correspods to the simd lanes
and the dim-0 of the value corresponds to the chunk size stored per lane. So `store_scatter`
has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
introduced on purpose, making sure users are aware of this implicit transformation.
- In SIMT mode, the input vector represents the data to be stored by each work-item.
- Each work-item stores a `chunk_size` number of elements.
+ In SIMT mode, the result is a 1D vector that represents the data to be stored by
+ each work-item. If size is not 1, size should be equal to the chunk size.
+
+ Arguments:
+ - `value`: represents the data to be stored.
+ - `dest`: represents the memory region to be stored to, which can be either a
+ tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+ In case of tensor_desc, offsets come from the producer create_tdesc op.
+ tensor_desc cannot be used in SIMT mode.
+ - `offsets`: represents offsets from dest. required if `source` in not a TensorDescType.
+ offsets is a vector of `index` type and vector length is either the subgroup size
+ or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+ - `mask`: is a vector of `i1` type, which is used to mask out the memory access.
+ mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+ scalar mask is also valid for SIMT mode.
+ - `chunk_size`: (optional) represents contiguous number of elements to store to per work item.
+ - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
Example 1:
```mlir
@@ -836,19 +945,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
: vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
```
- Example 3 (SIMT mode):
- ```mlir
- xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
- l2_hint = #xegpu.cache_hint<write_back>,
- l3_hint = #xegpu.cache_hint<write_through>}>
- : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
- ```
-
- Example 4:
+ Example 3:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
The dest operand could be a raw pointer (uint64_t).
- Please refer to create_tdesc for the restriction of memref.
+ Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -860,19 +961,27 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
: memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
```
+ Example 4 (SIMT mode):
+ SIMT mode only accepts the offsets variant. chunk_size can be inferred from value
+ type. In this example, chunk_size is 8.
+ ```mlir
+ xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>,
+ l2_hint = #xegpu.cache_hint<write_back>,
+ l3_hint = #xegpu.cache_hint<write_through>}>
+ : vector<8xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
+ ```
+
}];
- let arguments = (ins
- XeGPU_ValueType: $value,
- XeGPU_GatherScatterSourceType: $dest,
- Optional<XeGPU_OffsetType>: $offsets,
- XeGPU_MaskType: $mask,
- OptionalAttr<I64Attr>: $chunk_size,
- OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+ let arguments = (ins AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value,
+ XeGPU_GatherScatterSourceType:$dest,
+ Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+ AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
- let extraClassDeclaration = extraBaseClassDeclaration # [{
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
Type getDestType() {
return getDest().getType();
}
@@ -888,6 +997,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
return dyn_cast<xegpu::TensorDescType>(getDestType());
}
+ mlir::Type getElementType() {
+ auto type = getValue().getType();
+ return getElementTypeOrSelf(type);
+ }
+
VectorType getValueType() {
return llvm::dyn_cast<VectorType>(getValue().getType());
}
@@ -901,15 +1015,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
- $mask
- prop-dict
+ $mask
+ prop-dict
attr-dict `:` type(operands)
}];
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Value": $value, "Value": $dest,
+ "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+ "IntegerAttr": $chunk_size,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -1134,8 +1254,8 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
the IR is lowered to WI level because that is the end result of all distributions.
}];
let arguments = (ins XeGPU_VectorType: $source,
- XeGPU_LayoutAttr: $input_layout,
- XeGPU_LayoutAttr: $target_layout);
+ DistributeLayoutAttr: $input_layout,
+ DistributeLayoutAttr: $target_layout);
let results = (outs XeGPU_VectorType: $result);
let assemblyFormat = [{
$source prop-dict attr-dict `:` type($source)
@@ -1146,4 +1266,161 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
let hasCanonicalizer = 1;
}
+def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
+class StaticShared1DMemRefOf<list<Type> allowedTypes> :
+ ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
+ "statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
+ "mlir::MemRefType">;
+
+class SizeInBits<string name> :
+ StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
+ "*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
+class AllMemSizesMatch<list<string> names> :
+ AllMatchSameOperatorTrait<names, SizeInBits<"_self">.result,
+ "size in bits">;
+
+def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
+ AllMemSizesMatch<["source", "mem_desc"]>]> {
+ let summary = "Create a memory descriptor.";
+ let description = [{
+ Creates a memory descriptor from a shared local memory (SLM) buffer, and xegpu
+ specific memory layout. The resulting memory descriptor has to have the same size
+ as the underlying shared local memory.
+
+ Arguments:
+ - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
+ Results:
+ - `mem_desc` : the memory descriptor.
+ }];
+ let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+ let results = (outs XeGPU_MemDesc:$mem_desc);
+ let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
+}
+
+def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
+ AllElementTypesMatch<["mem_desc", "res"]>,
+ AllRanksMatch<["mem_desc", "res"]>]> {
+ let arguments = (ins XeGPU_MemDesc:$mem_desc,
+ Variadic<Index>: $offsets,
+ DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<DistributeLayoutAttr>:$layout
+ );
+ let results = (outs XeGPU_ValueType:$res);
+ let assemblyFormat = [{
+ $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `` `:` type(operands) `->` type(results)
+ }];
+
+ let description = [{
+ This operation loads a 2D block of data from shared local memory (SLM) as specified
+ by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
+ subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+
+ Arguments:
+ - `mem_desc`: the memory descriptor identifying the SLM region.
+ - `offsets`: the coordinates within the matrix to read from.
+ - `layout`: [optional] An attribute for guiding distributions among
+ subgroups and/or work-items. It currently can accept either
+ LayoutAttr or SliceAttr.
+ Results:
+ - `res`: the matrix elements loaded from SLM.
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
+ ];
+ let extraClassDeclaration = [{
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+ }
+
+ ArrayRef<int64_t> getDataShape() {
+ return getRes().getType().getShape();
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
+ AllElementTypesMatch<["mem_desc", "data"]>,
+ AllRanksMatch<["mem_desc", "data"]>]> {
+ let arguments = (ins
+ XeGPU_ValueType:$data,
+ XeGPU_MemDesc:$mem_desc,
+ Variadic<Index>: $offsets,
+ DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<DistributeLayoutAttr>:$layout
+ );
+ let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `` `:` type(operands)}];
+ let description = [{
+ This operation stores a 2D `data` fragment into the shared local memory region
+ specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
+ subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+
+ Arguments:
+ - `mem_desc`: the memory descriptor specifying the SLM region.
+ - `offsets`: the coordinates within the matrix where the data will be written.
+ - `data`: the values to be stored in the matrix.
+ - `layout`: [optional] An attribute for guiding distributions among
+ subgroups and/or work-items. It currently can accept either
+ LayoutAttr or SliceAttr.
+ }];
+ let builders = [
+ OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
+ ];
+ let extraClassDeclaration = [{
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+ }
+
+ ArrayRef<int64_t> getDataShape() {
+ return getData().getType().getShape();
+ }
+
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
+ [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
+ let description = [{
+ Creates a subview of a memory descriptor. The resulting memory descriptor can have
+ a lower rank than the source; in this case, the result dimensions correspond to the
+ higher-order dimensions of the source memory descriptor.
+
+ Arguments:
+ - `src` : a memory descriptor.
+ - `offsets` : the coordinates within the matrix the subview will be created from.
+
+ Results:
+ - `res` : a memory descriptor with smaller size.
+
+ }];
+ let arguments = (ins XeGPU_MemDesc:$src,
+ Variadic<Index>:$offsets,
+ DenseI64ArrayAttr:$const_offsets);
+ let results = (outs XeGPU_MemDesc:$res);
+ let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
+ attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
+ let builders = [
+ OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
+ ];
+
+ let extraClassDeclaration = [{
+ mlir::Value getViewSource() { return getSrc(); }
+
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index b268cab..84902b2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -16,13 +16,17 @@ include "mlir/IR/BuiltinTypes.td"
def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
-def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
+def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
+def XeGPU_BaseAddrType
+ : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
def XeGPU_VectorType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
+def XeGPU_GatherScatterBaseAddrType
+ : AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1]>, XeGPU_PointerType]>;
// common base class for types in XeGPU dialect
class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -189,7 +193,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
-def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType
+ : AnyTypeOf<[XeGPU_TensorDesc, XeGPU_GatherScatterBaseAddrType]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
@@ -201,4 +206,53 @@ def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
}];
}
+def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "mlir::Type"> {
+ let summary = "MemDesc describing the data in SLM";
+ let description = [{
+ MemDesc represents a block of data stored in shared local memory.
+ By default, unless a layout attribute is provided, the data is stored
+ contiguously in row-major order within the region.
+
+ Examples:
+ ```mlir
+ // A multi-dimensional array stored in column-major order.
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128]>>
+
+ // A multi-dimensional array stored in a blocked layout. Elements within the same block
+ // are stored contiguously in memory. Blocks are stored in row-major order.
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<block = [8, 8]>>
+
+ // A multi-dimensional array stored in column-major order with blocked layout.
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>>
+ ```
+ }];
+ let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
+ "mlir::Type": $elementType,
+ OptionalParameter<"MemLayoutAttr">: $mem_layout);
+
+ let extraClassDeclaration = [{
+ bool hasRank() const { return true; }
+
+ MemDescType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, Type elementType) const {
+ return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
+ }
+
+ ArrayAttr getStrides() {
+ auto layout = getMemLayout();
+ if (layout && layout.hasAttr("stride")) {
+ return layout.getStrides();
+ }
+
+ // derive and return default strides
+ SmallVector<int64_t> defaultStrides;
+ llvm::append_range(defaultStrides, getShape().drop_front());
+ llvm::append_values(defaultStrides, 1);
+ Builder builder(getContext());
+ return builder.getI64ArrayAttr(defaultStrides);
+ }
+ }];
+
+ let hasCustomAssemblyFormat = true;
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt
index 9de7e87..77ca255 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name XeGPU)
-add_public_tablegen_target(MLIRXeGPUPassIncGen)
-add_dependencies(mlir-headers MLIRXeGPUPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRXeGPUPassIncGen)
add_mlir_doc(Passes XeGPUPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 488f358f..bad734d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,7 +9,9 @@
#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
namespace mlir {
class VectorType;
@@ -20,6 +22,7 @@ class ValueRange;
class TypeConverter;
namespace xegpu {
+class DistributeLayoutAttr;
class LayoutAttr;
class TensorDescType;
} // namespace xegpu
@@ -59,22 +62,33 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
LayoutAttr layout);
-/// Return the attribute name for the OpOperand to attach LayoutAttr
+/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
std::string getLayoutName(const OpOperand &operand);
-/// Return the attribute name for the OpResult to attach LayoutAttr
+/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
std::string getLayoutName(const OpResult result);
-/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
-/// values, the LayoutAttr is extracted from the TensorDescType itself. For
-/// other values, it is obtained from the attributes of the defining operation.
-/// Returns nullptr if no LayoutAttr is found.
-LayoutAttr getLayoutAttr(const Value value);
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
-/// Retrieves the LayoutAttr associated with a given OpOperand. It will
-/// first check the operand_layout_{id} of the owner operation. If not found,
-/// it will check the operand itself and its defining op.
-LayoutAttr getLayoutAttr(const OpOperand &opr);
+template <typename AttrTy>
+AttrTy getDistributeLayoutAttrOfType(const Value value) {
+ return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
+}
+
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+
+template <typename AttrTy>
+AttrTy getDistributeLayoutAttrOfType(const OpOperand &opr) {
+ return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(opr));
+}
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
template <typename T,
@@ -82,23 +96,24 @@ template <typename T,
std::is_same_v<T, OpResult>>>
void removeLayoutAttr(const T &operandOrResult);
-/// Removes the LayoutAttr for each OpOperand and OpResult of the given
-/// operation if they exist. If the operation contains regions, it is also
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
+/// given operation if they exist. If the operation contains regions, it is also
/// applied recursively to the contained operations
void removeLayoutAttrs(Operation *op);
-/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
+/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
/// it to the owner's dictionary attributes
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
std::is_same_v<T, OpResult>>>
-void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);
+void setDistributeLayoutAttr(const T &operandOrResult,
+ const DistributeLayoutAttr layout);
-/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
-/// If the operation contains regions, it is also applied recursively to the
-/// contained operations
-void setLayoutAttrs(Operation *op,
- function_ref<LayoutAttr(Value)> getLayoutImpl);
+/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
+/// operation. If the operation contains regions, it is also applied recursively
+/// to the contained operations
+void setDistributeLayoutAttrs(
+ Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
/// Extract a set of small vectors from a value with a given shape using
/// vector.extract_stride_slice
@@ -123,6 +138,25 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
void doSCFStructuralTypeConversionWithTensorType(Operation *op,
TypeConverter converter);
+/// Retrieves the chip string from the XeVM target attribute of the parent
+/// GPU module operation. Returns the chip identifier if found, or nullopt
+/// if no GPU module parent or XeVM target attribute exists.
+std::optional<std::string> getChipStr(Operation *op);
+
+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+/// lhs = [l1, l2, l3], rhs = [r1, r2]
+/// Result: [11, l2+r1, l3+r2]
+SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> lhs,
+ ArrayRef<OpFoldResult> rhs);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
index 96ccebc..5bd71d6 100644
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -227,6 +227,13 @@ public:
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
symbolMap);
+ /// Initialize the ExecutionEngine. Global constructors specified by
+ /// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
+ /// binary compiled from `gpu.module` gets loaded during initialization. Make
+ /// sure all symbols are resolvable before initialization by calling
+ /// `registerSymbols` or including shared libraries.
+ void initialize();
+
private:
/// Ordering of llvmContext and jit is important for destruction purposes: the
/// jit must be destroyed before the context.
@@ -250,6 +257,8 @@ private:
/// Destroy functions in the libraries loaded by the ExecutionEngine that are
/// called when this ExecutionEngine is destructed.
SmallVector<LibraryDestroyFn> destroyFns;
+
+ bool isInitialized = false;
};
} // namespace mlir
diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index 6e72f7c..d66d757 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -151,7 +151,7 @@ public:
AllocFunType allocFun = &::malloc,
std::function<void(StridedMemRefType<T, Rank>)> freeFun =
[](StridedMemRefType<T, Rank> descriptor) {
- ::free(descriptor.data);
+ ::free(descriptor.basePtr);
})
: freeFunc(freeFun) {
if (shapeAlloc.empty())
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index e486bb62..85ce66f 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -205,12 +205,14 @@ public:
}
/// Return an iterator range over the operation within this block excluding
- /// the terminator operation at the end.
+ /// the terminator operation at the end. If the block has no terminator,
+ /// return an iterator range over the entire block. If it is unknown if the
+ /// block has a terminator (i.e., last block operation is unregistered), also
+ /// return an iterator range over the entire block.
iterator_range<iterator> without_terminator() {
if (begin() == end())
return {begin(), end()};
- auto endIt = --end();
- return {begin(), endIt};
+ return without_terminator_impl();
}
//===--------------------------------------------------------------------===//
@@ -221,7 +223,8 @@ public:
/// the block might have a valid terminator operation.
Operation *getTerminator();
- /// Check whether this block might have a terminator.
+ /// Return "true" if this block might have a terminator. Return "true" if
+ /// the last operation is unregistered.
bool mightHaveTerminator();
//===--------------------------------------------------------------------===//
@@ -402,6 +405,9 @@ public:
void printAsOperand(raw_ostream &os, AsmState &state);
private:
+ /// Same as `without_terminator`, but assumes that the block is not empty.
+ iterator_range<iterator> without_terminator_impl();
+
/// Pair of the parent object that owns this block and a bit that signifies if
/// the operations within this block have a valid ordering.
llvm::PointerIntPair<Region *, /*IntBits=*/1, bool> parentValidOpOrderPair;
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 846547f..683e2fe 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -8,37 +8,36 @@ mlir_tablegen(OpAsmOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(OpAsmOpInterface.cpp.inc -gen-op-interface-defs)
mlir_tablegen(OpAsmTypeInterface.h.inc -gen-type-interface-decls)
mlir_tablegen(OpAsmTypeInterface.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIROpAsmInterfaceIncGen)
-add_dependencies(mlir-generic-headers MLIROpAsmInterfaceIncGen)
+add_mlir_generic_tablegen_target(MLIROpAsmInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRBuiltinAttributesIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinAttributeInterfaces.td)
mlir_tablegen(BuiltinAttributeInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(BuiltinAttributeInterfaces.cpp.inc -gen-attr-interface-defs)
-add_public_tablegen_target(MLIRBuiltinAttributeInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinAttributeInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td)
mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls)
mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs)
-add_public_tablegen_target(MLIRBuiltinDialectIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinDialectIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinDialectBytecode.td)
mlir_tablegen(BuiltinDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Builtin")
-add_public_tablegen_target(MLIRBuiltinDialectBytecodeIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinDialectBytecodeIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinLocationAttributes.td)
mlir_tablegen(BuiltinLocationAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinLocationAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRBuiltinLocationAttributesIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinLocationAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinOps.td)
mlir_tablegen(BuiltinOps.h.inc -gen-op-decls)
mlir_tablegen(BuiltinOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRBuiltinOpsIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinOpsIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
@@ -46,17 +45,17 @@ mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
-add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRBuiltinTypeInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIRBuiltinTypeInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
-add_public_tablegen_target(MLIRTensorEncodingIncGen)
+add_mlir_generic_tablegen_target(MLIRTensorEncodingIncGen)
add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 18da85a..e1869c1 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -800,6 +800,10 @@ def IntPowerOf2 : AttrConstraint<
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
"whose value is a power of two > 0">;
+def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
+
+class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;
+
class ArrayMaxCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
"with at most " # n # " elements">;
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index ff6cec6..7eba68f 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -289,7 +289,7 @@ class IntEnum<string name, string summary, list<EnumCase> cases, int width>
class I32Enum<string name, string summary, list<EnumCase> cases>
: IntEnum<name, summary, cases, 32>;
class I64Enum<string name, string summary, list<EnumCase> cases>
- : IntEnum<name, summary, cases, 32>;
+ : IntEnum<name, summary, cases, 64>;
// An enum attribute backed by IntegerAttr.
//
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index ef8dab8..9690029 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -34,6 +34,9 @@ class MLIRContextImpl;
class RegisteredOperationName;
class StorageUniquer;
class IRUnit;
+namespace remark::detail {
+class RemarkEngine;
+} // namespace remark::detail
/// MLIRContext is the top-level object for a collection of MLIR operations. It
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -212,6 +215,13 @@ public:
/// Returns the diagnostic engine for this context.
DiagnosticEngine &getDiagEngine();
+ /// Returns the remark engine for this context, or nullptr if none has been
+ /// set.
+ remark::detail::RemarkEngine *getRemarkEngine();
+
+ /// Set the remark engine for this context.
+ void setRemarkEngine(std::unique_ptr<remark::detail::RemarkEngine> engine);
+
/// Returns the storage uniquer used for creating affine constructs.
StorageUniquer &getAffineUniquer();
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 9e5fb56..af8c072 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -603,6 +603,51 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
string transform>
: TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
+// Checks that each inputArg has the same type as the corresponding entry
+// in allowedTypes
+class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
+ PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types",
+ !foldl(TruePred, !range(!size(inputArgs)), acc, i,
+ And<[acc,
+ SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+ allowedTypes[i].predicate>
+ ]>)> {
+ assert !eq(!size(inputArgs), !size(allowedTypes)),
+ "inputArgs and allowedTypes lists must have the same length";
+
+ list<string> inputArgList = inputArgs;
+ list<Type> allowedTypeList = allowedTypes;
+}
+
+// Checks that inputArgs match one of the allowed type combinations.
+// Each combination in allowedCombinations must have the same number of types
+// as there are inputArgs.
+class InputAddressIsCombinationOf<list<string> inputArgs,
+ list<list<Type>> allowedCombinations,
+ string description = ""> :
+ PredOpTrait<!if(!empty(description),
+ "operands {" # !interleave(inputArgs, ", ") # "} match one of the allowed type combinations",
+ description),
+ Or<!foreach(combination, allowedCombinations,
+ !foldl(TruePred, !range(!size(inputArgs)), acc, i,
+ And<[acc,
+ SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+ combination[i].predicate>
+ ]>))>> {
+ assert !gt(!size(allowedCombinations), 0),
+ "allowedCombinations must not be empty";
+
+ // Validate that each combination has the same number of types as inputArgs
+ defvar inputArgSize = !size(inputArgs);
+ defvar validSizes = !foldl(1, allowedCombinations, acc, combination,
+ !and(acc, !eq(inputArgSize, !size(combination))));
+ assert validSizes,
+ "each combination in allowedCombinations must have the same length as inputArgs";
+
+ list<string> inputArgList = inputArgs;
+ list<list<Type>> allowedCombinationList = allowedCombinations;
+}
+
// Type Constraint operand `idx`'s Element type is `type`.
class TCopVTEtIs<int idx, Type type> : And<[
CPred<"$_op.getNumOperands() > " # idx>,
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 4f89f8b..5569392c 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1106,6 +1106,8 @@ inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {
/// useful to act as a "stream modifier" to customize printing an operation
/// with a stream using the operator<< overload, e.g.:
/// llvm::dbgs() << OpWithFlags(op, OpPrintingFlags().skipRegions());
+/// This always prints the operation with the local scope, to avoid introducing
+/// spurious newlines in the stream.
class OpWithFlags {
public:
OpWithFlags(Operation *op, OpPrintingFlags flags = {})
@@ -1116,11 +1118,11 @@ public:
private:
Operation *op;
OpPrintingFlags theFlags;
- friend raw_ostream &operator<<(raw_ostream &os, const OpWithFlags &op);
+ friend raw_ostream &operator<<(raw_ostream &os, OpWithFlags op);
};
-inline raw_ostream &operator<<(raw_ostream &os,
- const OpWithFlags &opWithFlags) {
+inline raw_ostream &operator<<(raw_ostream &os, OpWithFlags opWithFlags) {
+ opWithFlags.flags().useLocalScope();
opWithFlags.op->print(os, opWithFlags.flags());
return os;
}
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b5a93a0..57e73c1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -311,14 +311,14 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
/// opposed to a raw Operation.
template <typename SourceOp>
struct OpRewritePattern
- : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ : public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
/// ops.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
ArrayRef<StringRef> generatedNames = {})
- : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
+ : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
SourceOp::getOperationName(), benefit, context, generatedNames) {}
};
@@ -327,10 +327,10 @@ struct OpRewritePattern
/// of a raw Operation.
template <typename SourceOp>
struct OpInterfaceRewritePattern
- : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ : public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
- : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
+ : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
benefit, context) {}
};
diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td
index a6221f9..a7ade06 100644
--- a/mlir/include/mlir/IR/Properties.td
+++ b/mlir/include/mlir/IR/Properties.td
@@ -773,9 +773,10 @@ class OptionalProp<Property p, bit canDelegateParsing = 1>
}];
let writeToMlirBytecode = [{
$_writer.writeOwnedBool($_storage.has_value());
- if (!$_storage.has_value())
- return;
- }] # !subst("$_storage", "(*($_storage))", p.writeToMlirBytecode);
+ if ($_storage.has_value()) {
+ }] # !subst("$_storage", "(*($_storage))", p.writeToMlirBytecode) # [{
+ }
+ }];
let hashProperty = !if(!empty(p.hashProperty), p.hashProperty,
[{ hash_value($_storage.has_value() ? std::optional<::llvm::hash_code>{}] #
diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h
new file mode 100644
index 0000000..26d6547
--- /dev/null
+++ b/mlir/include/mlir/IR/Remarks.h
@@ -0,0 +1,520 @@
+//===- Remarks.h - MLIR Optimization Remark ----------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities for emitting optimization remarks.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_REMARKS_H
+#define MLIR_IR_REMARKS_H
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/Remarks/Remark.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Regex.h"
+#include <optional>
+
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir::remark {
+
+/// Define an the set of categories to accept. By default none are, the provided
+/// regex matches against the category names for each kind of remark.
+struct RemarkCategories {
+ std::optional<std::string> passed, missed, analysis, failed;
+};
+
+/// Categories describe the outcome of an transformation, not the mechanics of
+/// emitting/serializing remarks.
+enum class RemarkKind {
+ RemarkUnknown = 0,
+
+ /// An optimization was applied.
+ RemarkPassed,
+
+ /// A profitable optimization opportunity was found but not applied.
+ RemarkMissed,
+
+ /// The compiler attempted the optimization but failed (e.g., legality
+ /// checks, or better opportunites).
+ RemarkFailure,
+
+ /// Informational context (e.g., analysis numbers) without a pass/fail
+ /// outcome.
+ RemarkAnalysis,
+};
+
+using namespace llvm;
+
+/// Options to create a Remark
+struct RemarkOpts {
+ StringRef remarkName; // Identifiable name
+ StringRef categoryName; // Category name (subject to regex filtering)
+ StringRef subCategoryName; // Subcategory name
+ StringRef functionName; // Function name if available
+
+ // Construct RemarkOpts from a remark name.
+ static constexpr RemarkOpts name(StringRef n) {
+ return RemarkOpts{n, {}, {}, {}};
+ }
+ /// Return a copy with the category set.
+ constexpr RemarkOpts category(StringRef v) const {
+ return {remarkName, v, subCategoryName, functionName};
+ }
+ /// Return a copy with the subcategory set.
+ constexpr RemarkOpts subCategory(StringRef v) const {
+ return {remarkName, categoryName, v, functionName};
+ }
+ /// Return a copy with the function name set.
+ constexpr RemarkOpts function(StringRef v) const {
+ return {remarkName, categoryName, subCategoryName, v};
+ }
+};
+
+} // namespace mlir::remark
+
+namespace mlir::remark::detail {
+//===----------------------------------------------------------------------===//
+// Remark Base Class
+//===----------------------------------------------------------------------===//
+class Remark {
+
+public:
+ Remark(RemarkKind remarkKind, DiagnosticSeverity severity, Location loc,
+ RemarkOpts opts)
+ : remarkKind(remarkKind), functionName(opts.functionName), loc(loc),
+ categoryName(opts.categoryName), subCategoryName(opts.subCategoryName),
+ remarkName(opts.remarkName) {
+ if (!categoryName.empty() && !subCategoryName.empty()) {
+ (llvm::Twine(categoryName) + ":" + subCategoryName)
+ .toStringRef(fullCategoryName);
+ }
+ }
+
+ // Remark argument that is a key-value pair that can be printed as machine
+ // parsable args.
+ struct Arg {
+ std::string key;
+ std::string val;
+ Arg(llvm::StringRef m) : key("Remark"), val(m) {}
+ Arg(llvm::StringRef k, llvm::StringRef v) : key(k), val(v) {}
+ Arg(llvm::StringRef k, std::string v) : key(k), val(std::move(v)) {}
+ Arg(llvm::StringRef k, const char *v) : Arg(k, llvm::StringRef(v)) {}
+ Arg(llvm::StringRef k, Value v);
+ Arg(llvm::StringRef k, Type t);
+ Arg(llvm::StringRef k, bool b) : key(k), val(b ? "true" : "false") {}
+
+ // One constructor for all arithmetic types except bool.
+ template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T> &&
+ !std::is_same_v<T, bool>>>
+ Arg(llvm::StringRef k, T v) : key(k) {
+ if constexpr (std::is_floating_point_v<T>) {
+ llvm::raw_string_ostream os(val);
+ os << v;
+ } else if constexpr (std::is_signed_v<T>) {
+ val = llvm::itostr(static_cast<long long>(v));
+ } else {
+ val = llvm::utostr(static_cast<unsigned long long>(v));
+ }
+ }
+ };
+
+ void insert(llvm::StringRef s);
+ void insert(Arg a);
+
+ void print(llvm::raw_ostream &os, bool printLocation = false) const;
+
+ Location getLocation() const { return loc; }
+ /// Diagnostic -> Remark
+ llvm::remarks::Remark generateRemark() const;
+
+ StringRef getFunction() const {
+ if (!functionName.empty())
+ return functionName;
+ return "<unknown function>";
+ }
+
+ llvm::StringRef getCategoryName() const { return categoryName; }
+
+ llvm::StringRef getFullCategoryName() const {
+ if (categoryName.empty() && subCategoryName.empty())
+ return {};
+ if (subCategoryName.empty())
+ return categoryName;
+ if (categoryName.empty())
+ return subCategoryName;
+ return fullCategoryName;
+ }
+
+ StringRef getRemarkName() const {
+ if (remarkName.empty())
+ return "<unknown remark name>";
+ return remarkName;
+ }
+
+ std::string getMsg() const;
+
+ ArrayRef<Arg> getArgs() const { return args; }
+
+ llvm::remarks::Type getRemarkType() const;
+
+ StringRef getRemarkTypeString() const;
+
+protected:
+ /// Keeps the MLIR diagnostic kind, which is used to determine the
+ /// diagnostic kind in the LLVM remark streamer.
+ RemarkKind remarkKind;
+ /// Name of the convering function like interface
+ StringRef functionName;
+
+ Location loc;
+ /// Sub category passname e.g., "Unroll" or "UnrollAndJam"
+ StringRef categoryName;
+
+ /// Sub category name "Loop Optimizer"
+ StringRef subCategoryName;
+
+ /// Combined name for category and sub-category
+ SmallString<64> fullCategoryName;
+
+ /// Remark identifier
+ StringRef remarkName;
+
+ /// Args collected via the streaming interface.
+ SmallVector<Arg, 4> args;
+
+private:
+ /// Convert the MLIR diagnostic severity to LLVM diagnostic severity.
+ static llvm::DiagnosticSeverity
+ makeLLVMSeverity(DiagnosticSeverity severity) {
+ switch (severity) {
+ case DiagnosticSeverity::Note:
+ return llvm::DiagnosticSeverity::DS_Note;
+ case DiagnosticSeverity::Warning:
+ return llvm::DiagnosticSeverity::DS_Warning;
+ case DiagnosticSeverity::Error:
+ return llvm::DiagnosticSeverity::DS_Error;
+ case DiagnosticSeverity::Remark:
+ return llvm::DiagnosticSeverity::DS_Remark;
+ }
+ llvm_unreachable("Unknown diagnostic severity");
+ }
+ /// Convert the MLIR remark kind to LLVM diagnostic kind.
+ static llvm::DiagnosticKind makeLLVMKind(RemarkKind remarkKind) {
+ switch (remarkKind) {
+ case RemarkKind::RemarkUnknown:
+ return llvm::DiagnosticKind::DK_Generic;
+ case RemarkKind::RemarkPassed:
+ return llvm::DiagnosticKind::DK_OptimizationRemark;
+ case RemarkKind::RemarkMissed:
+ return llvm::DiagnosticKind::DK_OptimizationRemarkMissed;
+ case RemarkKind::RemarkFailure:
+ return llvm::DiagnosticKind::DK_OptimizationFailure;
+ case RemarkKind::RemarkAnalysis:
+ return llvm::DiagnosticKind::DK_OptimizationRemarkAnalysis;
+ }
+ llvm_unreachable("Unknown diagnostic kind");
+ }
+};
+
+inline Remark &operator<<(Remark &r, StringRef s) {
+ r.insert(s);
+ return r;
+}
+inline Remark &&operator<<(Remark &&r, StringRef s) {
+ r.insert(s);
+ return std::move(r);
+}
+inline Remark &operator<<(Remark &r, const Remark::Arg &kv) {
+ r.insert(kv);
+ return r;
+}
+
+//===----------------------------------------------------------------------===//
+// Shorthand aliases for different kinds of remarks.
+//===----------------------------------------------------------------------===//
+
+template <RemarkKind K, DiagnosticSeverity S>
+class OptRemarkBase final : public Remark {
+public:
+ explicit OptRemarkBase(Location loc, RemarkOpts opts)
+ : Remark(K, S, loc, opts) {}
+};
+
+using OptRemarkAnalysis =
+ OptRemarkBase<RemarkKind::RemarkAnalysis, DiagnosticSeverity::Remark>;
+
+using OptRemarkPass =
+ OptRemarkBase<RemarkKind::RemarkPassed, DiagnosticSeverity::Remark>;
+
+using OptRemarkMissed =
+ OptRemarkBase<RemarkKind::RemarkMissed, DiagnosticSeverity::Remark>;
+
+using OptRemarkFailure =
+ OptRemarkBase<RemarkKind::RemarkFailure, DiagnosticSeverity::Remark>;
+
+class RemarkEngine;
+
+//===----------------------------------------------------------------------===//
+// InFlightRemark
+//===----------------------------------------------------------------------===//
+
+/// Lazy text building for zero cost string formatting.
+struct LazyTextBuild {
+ llvm::StringRef key;
+ std::function<std::string()> thunk;
+};
+
+/// InFlightRemark is a RAII class that holds a reference to a Remark
+/// instance and allows to build the remark using the << operator. The remark
+/// is emitted when the InFlightRemark instance is destroyed, which happens
+/// when the scope ends or when the InFlightRemark instance is moved.
+/// Similar to InFlightDiagnostic, but for remarks.
+class InFlightRemark {
+public:
+ explicit InFlightRemark(std::unique_ptr<Remark> diag)
+ : remark(std::move(diag)) {}
+
+ InFlightRemark(RemarkEngine &eng, std::unique_ptr<Remark> diag)
+ : owner(&eng), remark(std::move(diag)) {}
+
+ InFlightRemark() = default; // empty ctor
+
+ InFlightRemark &operator<<(const LazyTextBuild &l) {
+ if (remark)
+ *remark << Remark::Arg(l.key, l.thunk());
+ return *this;
+ }
+
+ // Generic path, but *not* for Lazy
+ template <typename T, typename = std::enable_if_t<
+ !std::is_same_v<std::decay_t<T>, LazyTextBuild>>>
+ InFlightRemark &operator<<(T &&arg) {
+ if (remark)
+ *remark << std::forward<T>(arg);
+ return *this;
+ }
+
+ explicit operator bool() const { return remark != nullptr; }
+
+ ~InFlightRemark();
+
+ InFlightRemark(const InFlightRemark &) = delete;
+ InFlightRemark &operator=(const InFlightRemark &) = delete;
+ InFlightRemark(InFlightRemark &&) = default;
+ InFlightRemark &operator=(InFlightRemark &&) = default;
+
+private:
+ RemarkEngine *owner{nullptr};
+ std::unique_ptr<Remark> remark;
+};
+
+//===----------------------------------------------------------------------===//
+// MLIR Remark Streamer
+//===----------------------------------------------------------------------===//
+
+/// Base class for MLIR remark streamers that is used to stream
+/// optimization remarks to the underlying remark streamer. The derived classes
+/// should implement the `streamOptimizationRemark` method to provide the
+/// actual streaming implementation.
+class MLIRRemarkStreamerBase {
+public:
+ virtual ~MLIRRemarkStreamerBase() = default;
+ /// Stream an optimization remark to the underlying remark streamer. It is
+ /// called by the RemarkEngine to stream the optimization remarks.
+ ///
+ /// It must be overridden by the derived classes to provide
+ /// the actual streaming implementation.
+ virtual void streamOptimizationRemark(const Remark &remark) = 0;
+
+ virtual void finalize() {} // optional
+};
+
+//===----------------------------------------------------------------------===//
+// Remark Engine (MLIR Context will own this class)
+//===----------------------------------------------------------------------===//
+
+class RemarkEngine {
+private:
+ /// Regex that filters missed optimization remarks: only matching one are
+ /// reported.
+ std::optional<llvm::Regex> missFilter;
+ /// The category for passed optimization remarks.
+ std::optional<llvm::Regex> passedFilter;
+ /// The category for analysis remarks.
+ std::optional<llvm::Regex> analysisFilter;
+ /// The category for failed optimization remarks.
+ std::optional<llvm::Regex> failedFilter;
+ /// The MLIR remark streamer that will be used to emit the remarks.
+ std::unique_ptr<MLIRRemarkStreamerBase> remarkStreamer;
+ /// When is enabled, engine also prints remarks as mlir::emitRemarks.
+ bool printAsEmitRemarks = false;
+
+ /// Return true if missed optimization remarks are enabled, override
+ /// to provide different implementation.
+ bool isMissedOptRemarkEnabled(StringRef categoryName) const;
+
+ /// Return true if passed optimization remarks are enabled, override
+ /// to provide different implementation.
+ bool isPassedOptRemarkEnabled(StringRef categoryName) const;
+
+ /// Return true if analysis optimization remarks are enabled, override
+ /// to provide different implementation.
+ bool isAnalysisOptRemarkEnabled(StringRef categoryName) const;
+
+ /// Return true if analysis optimization remarks are enabled, override
+ /// to provide different implementation.
+ bool isFailedOptRemarkEnabled(StringRef categoryName) const;
+
+ /// Return true if any type of remarks are enabled for this pass.
+ bool isAnyRemarkEnabled(StringRef categoryName) const {
+ return isMissedOptRemarkEnabled(categoryName) ||
+ isPassedOptRemarkEnabled(categoryName) ||
+ isFailedOptRemarkEnabled(categoryName) ||
+ isAnalysisOptRemarkEnabled(categoryName);
+ }
+
+ /// Emit a remark using the given maker function, which should return
+ /// a Remark instance. The remark will be emitted using the main
+ /// remark streamer.
+ template <typename RemarkT, typename... Args>
+ InFlightRemark makeRemark(Args &&...args);
+
+ template <typename RemarkT>
+ InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts,
+ bool (RemarkEngine::*isEnabled)(StringRef)
+ const);
+
+public:
+ /// Default constructor is deleted, use the other constructor.
+ RemarkEngine() = delete;
+
+ /// Constructs Remark engine with optional category names. If a category
+ /// name is not provided, it is not enabled. The category names are used to
+ /// filter the remarks that are emitted.
+ RemarkEngine(bool printAsEmitRemarks, const RemarkCategories &cats);
+
+ /// Destructor that will close the output file and reset the
+ /// main remark streamer.
+ ~RemarkEngine();
+
+ /// Setup the remark engine with the given output path and format.
+ LogicalResult initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::string *errMsg);
+
+ /// Report a remark.
+ void report(const Remark &&remark);
+
+ /// Report a successful remark, this will create an InFlightRemark
+ /// that can be used to build the remark using the << operator.
+ InFlightRemark emitOptimizationRemark(Location loc, RemarkOpts opts);
+
+ /// Report a missed optimization remark
+ /// that can be used to build the remark using the << operator.
+ InFlightRemark emitOptimizationRemarkMiss(Location loc, RemarkOpts opts);
+
+ /// Report a failed optimization remark, this will create an InFlightRemark
+ /// that can be used to build the remark using the << operator.
+ InFlightRemark emitOptimizationRemarkFailure(Location loc, RemarkOpts opts);
+
+ /// Report an analysis remark, this will create an InFlightRemark
+ /// that can be used to build the remark using the << operator.
+ InFlightRemark emitOptimizationRemarkAnalysis(Location loc, RemarkOpts opts);
+};
+
+template <typename Fn, typename... Args>
+inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) {
+ MLIRContext *ctx = loc->getContext();
+
+ RemarkEngine *enginePtr = ctx->getRemarkEngine();
+
+ if (LLVM_UNLIKELY(enginePtr))
+ return (enginePtr->*fn)(loc, std::forward<Args>(args)...);
+
+ return {};
+}
+
+} // namespace mlir::remark::detail
+
+namespace mlir::remark {
+
+/// Create a Reason with llvm::formatv formatting.
+template <class... Ts>
+inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) {
+ return {"Reason", [=] { return llvm::formatv(fmt, ts...).str(); }};
+}
+
+/// Create a Suggestion with llvm::formatv formatting.
+template <class... Ts>
+inline detail::LazyTextBuild suggest(const char *fmt, Ts &&...ts) {
+ return {"Suggestion", [=] { return llvm::formatv(fmt, ts...).str(); }};
+}
+
+/// Create a Remark with llvm::formatv formatting.
+template <class... Ts>
+inline detail::LazyTextBuild add(const char *fmt, Ts &&...ts) {
+ return {"Remark", [=] { return llvm::formatv(fmt, ts...).str(); }};
+}
+
+template <class V>
+inline detail::LazyTextBuild metric(StringRef key, V &&v) {
+ using DV = std::decay_t<V>;
+ return {key, [key, vv = DV(std::forward<V>(v))]() mutable {
+ // Reuse Arg's formatting logic and return just the value string.
+ return detail::Remark::Arg(key, std::move(vv)).val;
+ }};
+}
+//===----------------------------------------------------------------------===//
+// Emitters
+//===----------------------------------------------------------------------===//
+
+/// Report an optimization remark that was passed.
+inline detail::InFlightRemark passed(Location loc, RemarkOpts opts) {
+ return withEngine(&detail::RemarkEngine::emitOptimizationRemark, loc, opts);
+}
+
+/// Report an optimization remark that was missed.
+inline detail::InFlightRemark missed(Location loc, RemarkOpts opts) {
+ return withEngine(&detail::RemarkEngine::emitOptimizationRemarkMiss, loc,
+ opts);
+}
+
+/// Report an optimization remark that failed.
+inline detail::InFlightRemark failed(Location loc, RemarkOpts opts) {
+ return withEngine(&detail::RemarkEngine::emitOptimizationRemarkFailure, loc,
+ opts);
+}
+
+/// Report an optimization analysis remark.
+inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) {
+ return withEngine(&detail::RemarkEngine::emitOptimizationRemarkAnalysis, loc,
+ opts);
+}
+
+//===----------------------------------------------------------------------===//
+// Setup
+//===----------------------------------------------------------------------===//
+
+/// Setup remarks for the context. This function will enable the remark engine
+/// and set the streamer to be used for optimization remarks. The remark
+/// categories are used to filter the remarks that will be emitted by the remark
+/// engine. If a category is not specified, it will not be emitted. If
+/// `printAsEmitRemarks` is true, the remarks will be printed as
+/// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and
+/// will be used to stream the remarks.
+LogicalResult enableOptimizationRemarks(
+ MLIRContext &ctx,
+ std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
+ const remark::RemarkCategories &cats, bool printAsEmitRemarks = false);
+
+} // namespace mlir::remark
+
+#endif // MLIR_IR_REMARKS_H
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index 1ab80fb..622024d 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -20,6 +20,7 @@ namespace mlir {
void registerFromLLVMIRTranslation();
void registerFromSPIRVTranslation();
+void registerFromWasmTranslation();
void registerToCppTranslation();
void registerToLLVMIRTranslation();
void registerToSPIRVTranslation();
@@ -36,6 +37,7 @@ inline void registerAllTranslations() {
registerFromLLVMIRTranslation();
registerFromSPIRVTranslation();
registerIRDLToCppTranslation();
+ registerFromWasmTranslation();
registerToCppTranslation();
registerToLLVMIRTranslation();
registerToSPIRVTranslation();
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 067e051..20cc267 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -24,8 +24,7 @@ mlir_tablegen(MemorySlotOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(MemorySlotOpInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(MemorySlotTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(MemorySlotTypeInterfaces.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRMemorySlotInterfacesIncGen)
-add_dependencies(mlir-generic-headers MLIRMemorySlotInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIRMemorySlotInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS DataLayoutInterfaces.td)
mlir_tablegen(DataLayoutAttrInterface.h.inc -gen-attr-interface-decls)
@@ -34,8 +33,7 @@ mlir_tablegen(DataLayoutOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(DataLayoutOpInterface.cpp.inc -gen-op-interface-defs)
mlir_tablegen(DataLayoutTypeInterface.h.inc -gen-type-interface-decls)
mlir_tablegen(DataLayoutTypeInterface.cpp.inc -gen-type-interface-defs)
-add_public_tablegen_target(MLIRDataLayoutInterfacesIncGen)
-add_dependencies(mlir-generic-headers MLIRDataLayoutInterfacesIncGen)
+add_mlir_generic_tablegen_target(MLIRDataLayoutInterfacesIncGen)
add_mlir_doc(DataLayoutInterfaces
DataLayoutAttrInterface
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index aef7ec6..9de20f0c 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -383,33 +383,71 @@ struct Write : public Effect::Base<Write> {};
// SideEffect Utilities
//===----------------------------------------------------------------------===//
-/// Returns true if `op` has only an effect of type `EffectTy`.
+/// Return "true" if `op` has unknown effects. I.e., the effects of the
+/// operation itself are unknown and the operation does not derive its effects
+/// from its nested operations. (`HasRecursiveMemoryEffects` trait is not
+/// implemented or it is unknown whether it is implemented or not.)
+bool hasUnknownEffects(Operation *op);
+
+/// Returns "true" if `op` has only an effect of type `EffectTy`. Returns
+/// "false" if `op` has unknown effects or other/additional effects. Recursive
+/// effects are not taken into account.
template <typename EffectTy>
bool hasSingleEffect(Operation *op);
-/// Returns true if `op` has only an effect of type `EffectTy` (and of no other
-/// type) on `value`.
+/// Returns "true" if `op` has only an effect of type `EffectTy` on `value`.
+/// Returns "false" if `op` has unknown effects or other/additional effects.
+/// Recursive effects are not taken into account.
template <typename EffectTy>
bool hasSingleEffect(Operation *op, Value value);
-/// Returns true if `op` has only an effect of type `EffectTy` (and of no other
-/// type) on `value` of type `ValueTy`.
+/// Returns "true" if `op` has only an effect of type `EffectTy` on `value` of
+/// type `ValueTy`. Returns "false" if `op` has unknown effects or
+/// other/additional effects. Recursive effects are not taken into account.
template <typename ValueTy, typename EffectTy>
bool hasSingleEffect(Operation *op, ValueTy value);
-/// Returns true if `op` has an effect of type `EffectTy`.
+/// Returns "true" if `op` has an effect of type `EffectTy`. Returns "false" if
+/// `op` has unknown effects. Recursive effects are not taken into account.
template <typename... EffectTys>
bool hasEffect(Operation *op);
-/// Returns true if `op` has an effect of type `EffectTy` on `value`.
+/// Returns "true" if `op` has an effect of type `EffectTy` on `value`. Returns
+/// "false" if `op` has unknown effects. Recursive effects are not taken into
+/// account.
template <typename... EffectTys>
bool hasEffect(Operation *op, Value value);
-/// Returns true if `op` has an effect of type `EffectTy` on `value` of type
-/// `ValueTy`.
+/// Returns "true" if `op` has an effect of type `EffectTy` on `value` of type
+/// `ValueTy`. Returns "false" if `op` has unknown effects. Recursive effects
+/// are not taken into account.
template <typename ValueTy, typename... EffectTys>
bool hasEffect(Operation *op, ValueTy value);
+/// Returns "true" if `op` might have an effect of type `EffectTy`. Returns
+/// "true" if the op has unknown effects. Recursive effects are not taken into
+/// account.
+template <typename... EffectTys>
+bool mightHaveEffect(Operation *op) {
+ return hasUnknownEffects(op) || hasEffect<EffectTys...>(op);
+}
+
+/// Returns "true" if `op` might have an effect of type `EffectTy` on `value`.
+/// Returns "true" if the op has unknown effects. Recursive effects are not
+/// taken into account.
+template <typename... EffectTys>
+bool mightHaveEffect(Operation *op, Value value) {
+ return hasUnknownEffects(op) || hasEffect<EffectTys...>(op, value);
+}
+
+/// Returns "true" if `op` might have an effect of type `EffectTy` on `value`
+/// of type `ValueTy`. Returns "true" if the op has unknown effects. Recursive
+/// effects are not taken into account.
+template <typename ValueTy, typename... EffectTys>
+bool mightHaveEffect(Operation *op, ValueTy value) {
+ return hasUnknownEffects(op) || hasEffect<EffectTys...>(op, value);
+}
+
/// Return true if the given operation is unused, and has no side effects on
/// memory that prevent erasing.
bool isOpTriviallyDead(Operation *op);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index d1401c2..ed213bf 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -26,7 +26,17 @@ def ViewLikeOpInterface : OpInterface<"ViewLikeOpInterface"> {
let methods = [
InterfaceMethod<
"Returns the source buffer from which the view is created.",
- "::mlir::Value", "getViewSource">
+ "::mlir::Value", "getViewSource">,
+ InterfaceMethod<
+ /*desc=*/[{ Returns the buffer which the view created. }],
+ /*retTy=*/"::mlir::Value",
+ /*methodName=*/"getViewDest",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op->getResult(0);
+ }]
+ >
];
}
diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
index e1f16c6..0c71f78 100644
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -377,7 +377,7 @@ private:
/// ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
/// };
template <typename T>
-class PassPipelineOptions : public detail::PassOptions {
+class PassPipelineOptions : public virtual detail::PassOptions {
public:
/// Factory that parses the provided options and returns a unique_ptr to the
/// struct.
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 7181648..43b6998 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -36,7 +36,7 @@
/// 9
///
/// Assuming all local orders match the numbering order:
-/// {5, 7, 6, 8, 9}
+/// {1, 5, 6, 7, 8, 9}
namespace mlir::query::matcher {
template <typename Matcher>
diff --git a/mlir/include/mlir/Reducer/CMakeLists.txt b/mlir/include/mlir/Reducer/CMakeLists.txt
index 1d92d07..3d09f87 100644
--- a/mlir/include/mlir/Reducer/CMakeLists.txt
+++ b/mlir/include/mlir/Reducer/CMakeLists.txt
@@ -1,5 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Reducer)
-add_public_tablegen_target(MLIRReducerIncGen)
+add_mlir_generic_tablegen_target(MLIRReducerIncGen)
add_mlir_doc(Passes ReducerPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Remark/RemarkStreamer.h b/mlir/include/mlir/Remark/RemarkStreamer.h
new file mode 100644
index 0000000..8bfd176
--- /dev/null
+++ b/mlir/include/mlir/Remark/RemarkStreamer.h
@@ -0,0 +1,49 @@
+//===- RemarkStreamer.h - MLIR Optimization Remark ---------------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines LLVMRemarkStreamer plugging class that uses LLVM's
+// streamer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Remarks.h"
+
+#include "llvm/Remarks/RemarkStreamer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+namespace mlir::remark::detail {
+
+/// Concrete streamer that writes LLVM optimization remarks to a file
+/// (YAML or Bitstream). Lives outside core.
+class LLVMRemarkStreamer final : public MLIRRemarkStreamerBase {
+public:
+ static FailureOr<std::unique_ptr<MLIRRemarkStreamerBase>>
+ createToFile(llvm::StringRef path, llvm::remarks::Format fmt);
+
+ void streamOptimizationRemark(const Remark &remark) override;
+ void finalize() override {}
+ ~LLVMRemarkStreamer() override;
+
+private:
+ LLVMRemarkStreamer() = default;
+
+ std::unique_ptr<class llvm::remarks::RemarkStreamer> remarkStreamer;
+ std::unique_ptr<class llvm::ToolOutputFile> file;
+};
+} // namespace mlir::remark::detail
+
+namespace mlir::remark {
+/// Enable optimization remarks to a file with the given path and format.
+/// The remark categories are used to filter the remarks that are emitted.
+/// If the printAsEmitRemarks flag is set, remarks will also be printed using
+/// mlir::emitRemarks.
+LogicalResult enableOptimizationRemarksWithLLVMStreamer(
+ MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt,
+ const RemarkCategories &cat, bool printAsEmitRemarks = false);
+
+} // namespace mlir::remark
diff --git a/mlir/include/mlir/Target/CMakeLists.txt b/mlir/include/mlir/Target/CMakeLists.txt
new file mode 100644
index 0000000..39d31dc
--- /dev/null
+++ b/mlir/include/mlir/Target/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(LLVMIR)
diff --git a/mlir/include/mlir/Target/LLVM/XeVM/Target.h b/mlir/include/mlir/Target/LLVM/XeVM/Target.h
new file mode 100644
index 0000000..6aab15c
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVM/XeVM/Target.h
@@ -0,0 +1,30 @@
+//===-- Target.h - MLIR XeVM target registration ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for attaching the XeVM target interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVM_XEVM_TARGET_H
+#define MLIR_TARGET_LLVM_XEVM_TARGET_H
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+namespace xevm {
+/// Registers the `TargetAttrInterface` for the `#xevm.target` attribute in
+/// the given registry.
+void registerXeVMTargetInterfaceExternalModels(mlir::DialectRegistry &registry);
+
+/// Registers the `TargetAttrInterface` for the `#xevm.target` attribute in
+/// the registry associated with the given context.
+void registerXeVMTargetInterfaceExternalModels(mlir::MLIRContext &context);
+} // namespace xevm
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVM_XEVM_TARGET_H
diff --git a/mlir/include/mlir/Target/LLVM/XeVM/Utils.h b/mlir/include/mlir/Target/LLVM/XeVM/Utils.h
new file mode 100644
index 0000000..5d523f1
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVM/XeVM/Utils.h
@@ -0,0 +1,63 @@
+//===-- Utils.h - MLIR XeVM target utils ------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This files declares XeVM target related utility classes and functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVM_XEVM_UTILS_H
+#define MLIR_TARGET_LLVM_XEVM_UTILS_H
+
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/Target/LLVM/ModuleToObject.h"
+
+namespace mlir {
+namespace xevm {
+
+/// Base class for all XeVM serializations from GPU modules into binary strings.
+/// By default this class serializes into LLVM bitcode.
+class SerializeGPUModuleBase : public LLVM::ModuleToObject {
+public:
+ SerializeGPUModuleBase(Operation &module, XeVMTargetAttr target,
+ const gpu::TargetOptions &targetOptions = {});
+
+ /// Returns the target attribute.
+ XeVMTargetAttr getTarget() const;
+
+ /// Loads the bitcode files in `librariesToLink`.
+ std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
+ loadBitcodeFiles(llvm::Module &module) override;
+
+ /// Returns the gpu module being serialized.
+ gpu::GPUModuleOp getGPUModuleOp();
+
+ /// Compiles to native code using `ocloc`.
+ std::optional<SmallVector<char, 0>> compileToBinary(const std::string &asmStr,
+ StringRef inputFormat);
+
+protected:
+ /// XeVM Target attribute.
+ XeVMTargetAttr xeTarget;
+ /// List of LLVM bitcode to link into after translation to LLVM IR.
+ /// The attributes can be StringAttr pointing to a file path, or
+ /// a Resource blob pointing to the LLVM bitcode in-memory.
+ SmallVector<Attribute> librariesToLink;
+
+ /// Returns the path to the tool used for serialization.
+ std::optional<std::string> findTool(StringRef tool);
+
+ /// GPU compilation target options.
+ gpu::TargetOptions targetOptions;
+};
+} // namespace xevm
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVM_XEVM_UTILS_H
diff --git a/mlir/include/mlir/Target/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Target/LLVMIR/CMakeLists.txt
new file mode 100644
index 0000000..e31af32
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Target/LLVMIR/DataLayoutImporter.h b/mlir/include/mlir/Target/LLVMIR/DataLayoutImporter.h
index 88ceaf1..4d432df 100644
--- a/mlir/lib/Target/LLVMIR/DataLayoutImporter.h
+++ b/mlir/include/mlir/Target/LLVMIR/DataLayoutImporter.h
@@ -11,8 +11,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_
-#define MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_
+#ifndef MLIR_TARGET_LLVMIR_DATALAYOUTIMPORTER_H
+#define MLIR_TARGET_LLVMIR_DATALAYOUTIMPORTER_H
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -38,23 +38,31 @@ namespace detail {
/// null if the bit width is not supported.
FloatType getFloatType(MLIRContext *context, unsigned width);
-/// Helper class that translates an LLVM data layout to an MLIR data layout
-/// specification. Only integer, float, pointer, alloca memory space, stack
-/// alignment, and endianness entries are translated. The class also returns all
-/// entries from the default data layout specification found in the language
-/// reference (https://llvm.org/docs/LangRef.html#data-layout) if they are not
-/// overwritten by the provided data layout.
+/// Helper class that translates an LLVM data layout string to an MLIR data
+/// layout specification. Only integer, float, pointer, alloca memory space,
+/// stack alignment, and endianness entries are translated. The class also
+/// returns all entries from the default data layout specification found in the
+/// language reference (https://llvm.org/docs/LangRef.html#data-layout) if they
+/// are not overwritten by the provided data layout.
class DataLayoutImporter {
public:
- DataLayoutImporter(MLIRContext *context,
- const llvm::DataLayout &llvmDataLayout)
- : context(context) {
- translateDataLayout(llvmDataLayout);
+ DataLayoutImporter(MLIRContext *context, StringRef dataLayoutStr)
+ : dataLayoutStr(dataLayoutStr), context(context) {
+ // Translate the `dataLayoutStr`. First, append the default data layout
+ // string specified in the language reference
+ // (https://llvm.org/docs/LangRef.html#data-layout) to the supplied string.
+ // The translation then parses the string and ignores the default value if a
+ // specific kind occurs in both strings. Additionally, the following default
+ // values exist:
+ // - non-default address space pointer specifications default to the default
+ // address space pointer specification
+ // - the alloca address space defaults to the default address space.
+ dataLayoutSpec = dataLayoutSpecFromDataLayoutStr();
}
/// Returns the MLIR data layout specification translated from the LLVM
/// data layout.
- DataLayoutSpecInterface getDataLayout() const { return dataLayout; }
+ DataLayoutSpecInterface getDataLayoutSpec() const { return dataLayoutSpec; }
/// Returns the last data layout token that has been processed before
/// the data layout translation failed.
@@ -65,8 +73,9 @@ public:
ArrayRef<StringRef> getUnhandledTokens() const { return unhandledTokens; }
private:
- /// Translates the LLVM `dataLayout` to an MLIR data layout specification.
- void translateDataLayout(const llvm::DataLayout &llvmDataLayout);
+ /// Translate the LLVM data layout string to an MLIR data layout
+ /// specification.
+ DataLayoutSpecInterface dataLayoutSpecFromDataLayoutStr();
/// Tries to parse the letter only prefix that identifies the specification
/// and removes the consumed characters from the beginning of the string.
@@ -116,17 +125,18 @@ private:
/// Adds legal int widths entry if there is none yet.
LogicalResult tryToEmplaceLegalIntWidthsEntry(StringRef token);
- std::string layoutStr = {};
+ std::string dataLayoutStr = {};
+ DataLayoutSpecInterface dataLayoutSpec;
+
StringRef lastToken = {};
SmallVector<StringRef> unhandledTokens;
llvm::MapVector<StringAttr, DataLayoutEntryInterface> keyEntries;
llvm::MapVector<TypeAttr, DataLayoutEntryInterface> typeEntries;
MLIRContext *context;
- DataLayoutSpecInterface dataLayout;
};
} // namespace detail
} // namespace LLVM
} // namespace mlir
-#endif // MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_
+#endif // MLIR_TARGET_LLVMIR_DATALAYOUTIMPORTER_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index e4670cb..05b66ac 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -25,6 +25,7 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
@@ -45,6 +46,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerNVVMDialectTranslation(registry);
registerOpenACCDialectTranslation(registry);
registerOpenMPDialectTranslation(registry);
+ registerPtrDialectTranslation(registry);
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);
registerVCIXDialectTranslation(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h
new file mode 100644
index 0000000..5c81762
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h
@@ -0,0 +1,31 @@
+//===- PtrToLLVMIRTranslation.h - `ptr` to LLVM IR --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for `ptr` dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the `ptr` dialect and the translation from it to the LLVM IR in the
+/// given registry;
+void registerPtrDialectTranslation(DialectRegistry &registry);
+
+/// Register the `ptr` dialect and the translation from it in the registry
+/// associated with the given context.
+void registerPtrDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Transforms/CMakeLists.txt b/mlir/include/mlir/Target/LLVMIR/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..662763b
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name TargetLLVMIRTransforms)
+add_mlir_dialect_tablegen_target(MLIRTargetLLVMIRTransformsIncGen)
+
+add_mlir_doc(Passes TargetLLVMIRTransforms ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Target/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Target/LLVMIR/Transforms/Passes.h
new file mode 100644
index 0000000..22e6fad
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Transforms/Passes.h
@@ -0,0 +1,26 @@
+//===- Passes.h - LLVM Target Pass Construction and Registration ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TRANSFORMS_PASSES_H
+#define MLIR_TARGET_LLVMIR_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace LLVM {
+
+#define GEN_PASS_DECL
+#define GEN_PASS_REGISTRATION
+#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc"
+
+void registerTargetLLVMPasses();
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Target/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Target/LLVMIR/Transforms/Passes.td
new file mode 100644
index 0000000..c858e9f
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Transforms/Passes.td
@@ -0,0 +1,46 @@
+//===-- Passes.td - LLVM Target pass definition file -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TRANSFORMS_PASSES
+#define MLIR_TARGET_LLVMIR_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LLVMTargetToDataLayout : Pass<"llvm-target-to-data-layout"> {
+ let summary = "Derive data layout attributes from LLVM target attributes";
+ let dependentDialects = ["mlir::DLTIDialect"];
+ let description = [{
+ Derive a `DataLayoutSpecInterface`-implementing data layout attribute from
+ the LLVM-backend target specified by the `TargetAttrInterface`-implementing
+ attribute attached to the target op at the name `llvm.target`.
+ }];
+ let options = [
+ Option<"initializeLLVMTargets", "initialize-llvm-targets", "bool",
+ /*default=*/"true",
+ "Whether to pre-load all available target machines, that LLVM is "
+ "configured to support, into the TargetRegistry.">
+ ];
+}
+
+def LLVMTargetToTargetFeatures : Pass<"llvm-target-to-target-features"> {
+ let summary = "Update attached #llvm.target's features per the described target";
+ let description = [{
+ Obtain the TargetMachine specified by the attached #llvm.target's attributes
+ and obtain from it the full list of features of the selected target. Updates
+ the attached #llvm.target so that its features reflect the full list of
+ features.
+ }];
+ let options = [
+ Option<"initializeLLVMTargets", "initialize-llvm-targets", "bool",
+ /*default=*/"true",
+ "Whether to pre-load all available target machines, that LLVM is "
+ "configured to support, into the TargetRegistry.">
+ ];
+}
+
+#endif // MLIR_TARGET_LLVMIR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Target/LLVMIR/Transforms/TargetUtils.h b/mlir/include/mlir/Target/LLVMIR/Transforms/TargetUtils.h
new file mode 100644
index 0000000..2930733
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Transforms/TargetUtils.h
@@ -0,0 +1,35 @@
+//===- TargetUtils.h - Utils to obtain LLVM's TargetMachine and DataLayout ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TRANSFORMS_TARGETUTILS_H
+#define MLIR_TARGET_LLVMIR_TRANSFORMS_TARGETUTILS_H
+
+#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+#include "llvm/Support/Threading.h"
+#include "llvm/Target/TargetMachine.h"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+/// Idempotent helper to register/initialize all backends that LLVM has been
+/// configured to support. Only runs the first time it is called.
+void initializeBackendsOnce();
+
+/// Helper to obtain the TargetMachine specified by the properties of the
+/// TargetAttrInterface-implementing attribute.
+FailureOr<std::unique_ptr<llvm::TargetMachine>>
+getTargetMachine(mlir::LLVM::TargetAttrInterface attr);
+
+/// Helper to obtain the DataLayout of the target specified by the properties of
+/// the TargetAttrInterface-implementing attribute.
+FailureOr<llvm::DataLayout> getDataLayout(mlir::LLVM::TargetAttrInterface attr);
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_TRANSFORMS_TARGETUTILS_H
diff --git a/mlir/include/mlir/Target/SPIRV/Serialization.h b/mlir/include/mlir/Target/SPIRV/Serialization.h
index 225777e..e474101 100644
--- a/mlir/include/mlir/Target/SPIRV/Serialization.h
+++ b/mlir/include/mlir/Target/SPIRV/Serialization.h
@@ -15,6 +15,7 @@
#include "mlir/Support/LLVM.h"
#include <cstdint>
+#include <string>
namespace mlir {
class MLIRContext;
@@ -27,12 +28,33 @@ struct SerializationOptions {
bool emitSymbolName = true;
/// Whether to emit `OpLine` location information for SPIR-V ops.
bool emitDebugInfo = false;
+ /// Whether to store a module to an additional file during
+ /// serialization. This is used to store the SPIR-V module to the
+ /// file in addition to writing it to `os` passed from the calling
+ /// tool. This saved file is later used for validation.
+ bool saveModuleForValidation = false;
+ /// A prefix prepended to the file used when `saveModuleForValidation`
+ /// is set to `true`. This can either be a file prefix, or a relative or
+ /// or an absolute path followed by the prefix. For example:
+ ///
+ /// * "foo" - Create files with a `foo` prefix in the current working
+ /// directory. For example: `fooXYZ123.spv`, `fooABC456.spv` ...
+ /// `fooXXXXXX.spv`. The last 6 characters will be a unique combination
+ /// as generated by `llvm::sys::fs::createUniqueFile`.
+ ///
+ /// * "my/dir/foo" - Create files in `my/dir` with a `foo` prefix. The
+ /// `my/dir` need to exists. For example: `fooXYZ123.spv`,
+ /// `fooABC456.spv` ... `fooXXXXXX.spv` will be created and stored in
+ /// `/my/dir`. Filenames follow the same pattern as above.
+ ///
+ /// * "/home/user/my/dir" - Same as above but using an absolute path.
+ std::string validationFilePrefix = "";
};
-/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
+/// Serializes the given SPIR-V `moduleOp` and writes to `binary`. On failure,
/// reports errors to the error handler registered with the MLIR context for
-/// `module`.
-LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
+/// `moduleOp`.
+LogicalResult serialize(ModuleOp moduleOp, SmallVectorImpl<uint32_t> &binary,
const SerializationOptions &options = {});
} // namespace spirv
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
new file mode 100644
index 0000000..21adde8
--- /dev/null
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -0,0 +1,143 @@
+//===- WasmBinaryEncoding.h - Byte encodings for Wasm binary format ===----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Define various flags used to encode instructions, types, etc. in
+// WebAssembly binary format.
+//
+// These encodings are defined in the WebAssembly binary format specification.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TARGET_WASMBINARYENCODING
+#define MLIR_TARGET_WASMBINARYENCODING
+
+#include <cstddef>
+
+namespace mlir {
+struct WasmBinaryEncoding {
+ /// Byte encodings for Wasm instructions.
+ struct OpCode {
+ // Locals, globals, constants.
+ static constexpr std::byte localGet{0x20};
+ static constexpr std::byte localSet{0x21};
+ static constexpr std::byte localTee{0x22};
+ static constexpr std::byte globalGet{0x23};
+ static constexpr std::byte constI32{0x41};
+ static constexpr std::byte constI64{0x42};
+ static constexpr std::byte constFP32{0x43};
+ static constexpr std::byte constFP64{0x44};
+
+ // Numeric operations.
+ static constexpr std::byte clzI32{0x67};
+ static constexpr std::byte ctzI32{0x68};
+ static constexpr std::byte popcntI32{0x69};
+ static constexpr std::byte addI32{0x6A};
+ static constexpr std::byte subI32{0x6B};
+ static constexpr std::byte mulI32{0x6C};
+ static constexpr std::byte divSI32{0x6d};
+ static constexpr std::byte divUI32{0x6e};
+ static constexpr std::byte remSI32{0x6f};
+ static constexpr std::byte remUI32{0x70};
+ static constexpr std::byte andI32{0x71};
+ static constexpr std::byte orI32{0x72};
+ static constexpr std::byte xorI32{0x73};
+ static constexpr std::byte shlI32{0x74};
+ static constexpr std::byte shrSI32{0x75};
+ static constexpr std::byte shrUI32{0x76};
+ static constexpr std::byte rotlI32{0x77};
+ static constexpr std::byte rotrI32{0x78};
+ static constexpr std::byte clzI64{0x79};
+ static constexpr std::byte ctzI64{0x7A};
+ static constexpr std::byte popcntI64{0x7B};
+ static constexpr std::byte addI64{0x7C};
+ static constexpr std::byte subI64{0x7D};
+ static constexpr std::byte mulI64{0x7E};
+ static constexpr std::byte divSI64{0x7F};
+ static constexpr std::byte divUI64{0x80};
+ static constexpr std::byte remSI64{0x81};
+ static constexpr std::byte remUI64{0x82};
+ static constexpr std::byte andI64{0x83};
+ static constexpr std::byte orI64{0x84};
+ static constexpr std::byte xorI64{0x85};
+ static constexpr std::byte shlI64{0x86};
+ static constexpr std::byte shrSI64{0x87};
+ static constexpr std::byte shrUI64{0x88};
+ static constexpr std::byte rotlI64{0x89};
+ static constexpr std::byte rotrI64{0x8A};
+ static constexpr std::byte absF32{0x8B};
+ static constexpr std::byte negF32{0x8C};
+ static constexpr std::byte ceilF32{0x8D};
+ static constexpr std::byte floorF32{0x8E};
+ static constexpr std::byte truncF32{0x8F};
+ static constexpr std::byte sqrtF32{0x91};
+ static constexpr std::byte addF32{0x92};
+ static constexpr std::byte subF32{0x93};
+ static constexpr std::byte mulF32{0x94};
+ static constexpr std::byte divF32{0x95};
+ static constexpr std::byte minF32{0x96};
+ static constexpr std::byte maxF32{0x97};
+ static constexpr std::byte copysignF32{0x98};
+ static constexpr std::byte absF64{0x99};
+ static constexpr std::byte negF64{0x9A};
+ static constexpr std::byte ceilF64{0x9B};
+ static constexpr std::byte floorF64{0x9C};
+ static constexpr std::byte truncF64{0x9D};
+ static constexpr std::byte sqrtF64{0x9F};
+ static constexpr std::byte addF64{0xA0};
+ static constexpr std::byte subF64{0xA1};
+ static constexpr std::byte mulF64{0xA2};
+ static constexpr std::byte divF64{0xA3};
+ static constexpr std::byte minF64{0xA4};
+ static constexpr std::byte maxF64{0xA5};
+ static constexpr std::byte copysignF64{0xA6};
+ static constexpr std::byte wrap{0xA7};
+ };
+
+ /// Byte encodings of types in Wasm binaries
+ struct Type {
+ static constexpr std::byte emptyBlockType{0x40};
+ static constexpr std::byte funcType{0x60};
+ static constexpr std::byte externRef{0x6F};
+ static constexpr std::byte funcRef{0x70};
+ static constexpr std::byte v128{0x7B};
+ static constexpr std::byte f64{0x7C};
+ static constexpr std::byte f32{0x7D};
+ static constexpr std::byte i64{0x7E};
+ static constexpr std::byte i32{0x7F};
+ };
+
+ /// Byte encodings of Wasm imports.
+ struct Import {
+ static constexpr std::byte typeID{0x00};
+ static constexpr std::byte tableType{0x01};
+ static constexpr std::byte memType{0x02};
+ static constexpr std::byte globalType{0x03};
+ };
+
+ /// Byte encodings for Wasm limits.
+ struct LimitHeader {
+ static constexpr std::byte lowLimitOnly{0x00};
+ static constexpr std::byte bothLimits{0x01};
+ };
+
+ /// Byte encodings describing the mutability of globals.
+ struct GlobalMutability {
+ static constexpr std::byte isConst{0x00};
+ static constexpr std::byte isMutable{0x01};
+ };
+
+ /// Byte encodings describing Wasm exports.
+ struct Export {
+ static constexpr std::byte function{0x00};
+ static constexpr std::byte table{0x01};
+ static constexpr std::byte memory{0x02};
+ static constexpr std::byte global{0x03};
+ };
+
+ static constexpr std::byte endByte{0x0B};
+};
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Target/Wasm/WasmImporter.h b/mlir/include/mlir/Target/Wasm/WasmImporter.h
new file mode 100644
index 0000000..3f05bbe
--- /dev/null
+++ b/mlir/include/mlir/Target/Wasm/WasmImporter.h
@@ -0,0 +1,31 @@
+//===- WasmImporter.h - Helpers to create WebAssembly emitter ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines helpers to import WebAssembly code using the WebAssembly
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_WASM_WASMIMPORTER_H
+#define MLIR_TARGET_WASM_WASMIMPORTER_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir::wasm {
+
+/// If `source` contains a valid Wasm binary file, this function returns a
+/// a ModuleOp containing the representation of the Wasm module encoded in
+/// the source file in the `wasmssa` dialect.
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+ MLIRContext *context);
+} // namespace mlir::wasm
+
+#endif // MLIR_TARGET_WASM_WASMIMPORTER_H
diff --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt
index cf01899..5fa52b2 100644
--- a/mlir/include/mlir/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Transforms/CMakeLists.txt
@@ -3,6 +3,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transforms)
mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header --prefix Transforms)
mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl --prefix Transforms)
-add_public_tablegen_target(MLIRTransformsPassIncGen)
+add_mlir_dialect_tablegen_target(MLIRTransformsPassIncGen)
add_mlir_doc(Passes GeneralPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 4e651a0..14dfbf1 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -139,7 +139,8 @@ public:
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +155,14 @@ public:
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -242,15 +251,28 @@ public:
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}
- /// Convert the given type. This function should return failure if no valid
+ /// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
+ ///
+ /// Note: This overload invokes only context-unaware type conversion
+ /// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
+ /// Convert the type of the given value. This function returns failure if no
+ /// valid conversion exists, success otherwise. If the new set of types is
+ /// empty, the type is removed and any usages of the existing value are
+ /// expected to be removed during conversion.
+ ///
+ /// Note: This overload invokes both context-aware and context-unaware type
+ /// conversion functions.
+ LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
+ Type convertType(Value v) const;
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -259,25 +281,36 @@ public:
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
+ template <typename TargetType>
+ TargetType convertType(Value v) const {
+ return dyn_cast_or_null<TargetType>(convertType(v));
+ }
- /// Convert the given set of types, filling 'results' as necessary. This
- /// returns failure if the conversion of any of the types fails, success
+ /// Convert the given types, filling 'results' as necessary. This returns
+ /// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;
+ /// Convert the types of the given values, filling 'results' as necessary.
+ /// This returns "failure" if the conversion of any of the types fails,
+ /// "success" otherwise.
+ LogicalResult convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const;
+
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
+ bool isLegal(Value value) const;
/// Return true if all of the given types are legal for this type converter.
- template <typename RangeT>
- std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
- !std::is_convertible<RangeT, Operation *>::value,
- bool>
- isLegal(RangeT &&range) const {
+ bool isLegal(TypeRange range) const {
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
}
+ bool isLegal(ValueRange range) const {
+ return llvm::all_of(range, [this](Value value) { return isLegal(value); });
+ }
+
/// Return true if the given operation has legal operand and result types.
bool isLegal(Operation *op) const;
@@ -296,6 +329,11 @@ public:
LogicalResult convertSignatureArgs(TypeRange types,
SignatureConversion &result,
unsigned origInputOffset = 0) const;
+ LogicalResult convertSignatureArg(unsigned inputNo, Value value,
+ SignatureConversion &result) const;
+ LogicalResult convertSignatureArgs(ValueRange values,
+ SignatureConversion &result,
+ unsigned origInputOffset = 0) const;
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
@@ -329,7 +367,7 @@ private:
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
- Type, SmallVectorImpl<Type> &)>;
+ PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a source conversion.
///
@@ -349,13 +387,14 @@ private:
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
- /// With callback of form: `std::optional<Type>(T)`
+ /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+ /// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
- wrapCallback(FnT &&callback) const {
+ wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results) {
- if (std::optional<Type> resultOpt = callback(type)) {
+ T typeOrValue, SmallVectorImpl<Type> &results) {
+ if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
@@ -365,20 +404,49 @@ private:
});
}
/// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+ /// T, SmallVectorImpl<Type> &)`, where `T` is a type.
template <typename T, typename FnT>
- std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- Type type,
+ PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- T derivedType = dyn_cast<T>(type);
+ T derivedType;
+ if (Type t = dyn_cast<Type>(typeOrValue)) {
+ derivedType = dyn_cast<T>(t);
+ } else if (Value v = dyn_cast<Value>(typeOrValue)) {
+ derivedType = dyn_cast<T>(v.getType());
+ } else {
+ llvm_unreachable("unexpected variant");
+ }
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
+ /// With callback of form: `std::optional<LogicalResult>(
+ /// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+ template <typename T, typename FnT>
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_same_v<T, Value>,
+ ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ hasContextAwareTypeConversions = true;
+ return [callback = std::forward<FnT>(callback)](
+ PointerUnion<Type, Value> typeOrValue,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ if (Type t = dyn_cast<Type>(typeOrValue)) {
+ // Context-aware type conversion was called with a type.
+ return std::nullopt;
+ } else if (Value v = dyn_cast<Value>(typeOrValue)) {
+ return callback(v, results);
+ }
+ llvm_unreachable("unexpected variant");
+ return std::nullopt;
+ };
+ }
/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
@@ -505,6 +573,12 @@ private:
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+ /// Whether the type converter has context-aware type conversions. I.e.,
+ /// conversion rules that depend on the SSA value instead of just the type.
+ /// Type conversion caching is deactivated when there are context-aware
+ /// conversions because the type converter may return different results for
+ /// the same input type.
+ bool hasContextAwareTypeConversions = false;
};
//===----------------------------------------------------------------------===//
@@ -521,8 +595,8 @@ public:
/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
- /// by the driver. By default, it calls this 1:1 overload or reports a fatal
- /// error if 1:N replacements were found.
+ /// by the driver. By default, it calls this 1:1 overload or fails to match
+ /// if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -534,7 +608,7 @@ public:
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
- return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ return dispatchTo1To1(*this, op, operands, rewriter);
}
/// Attempt to match and rewrite the IR root at the specified operation.
@@ -567,11 +641,26 @@ protected:
/// try to extract the single value of each range to construct a the inputs
/// for a 1:1 adaptor.
///
- /// This function produces a fatal error if at least one range has 0 or
- /// more than 1 value: "pattern 'name' does not support 1:N conversion"
- SmallVector<Value>
+ /// Returns failure if at least one range has 0 or more than 1 value.
+ FailureOr<SmallVector<Value>>
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+ /// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
+ /// if possible and emit diagnostic with a failure return value otherwise.
+ /// 'self' should be '*this' of the derived-pattern and is used to dispatch
+ /// to the correct 'matchAndRewrite' method in the derived pattern.
+ template <typename SelfPattern, typename SourceOp>
+ static LogicalResult dispatchTo1To1(const SelfPattern &self, SourceOp op,
+ ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter);
+
+ /// Same as above, but accepts an adaptor as operand.
+ template <typename SelfPattern, typename SourceOp>
+ static LogicalResult dispatchTo1To1(
+ const SelfPattern &self, SourceOp op,
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+ ConversionPatternRewriter &rewriter);
+
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
@@ -620,9 +709,7 @@ public:
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- SmallVector<Value> oneToOneOperands =
- getOneToOneAdaptorOperands(adaptor.getOperands());
- return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ return dispatchTo1To1(*this, op, adaptor, rewriter);
}
private:
@@ -666,7 +753,7 @@ public:
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
- return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ return dispatchTo1To1(*this, op, operands, rewriter);
}
private:
@@ -769,6 +856,12 @@ public:
/// Replace all the uses of the block argument `from` with `to`. This
/// function supports both 1:1 and 1:N replacements.
+ ///
+ /// Note: If `allowPatternRollback` is set to "true", this function replaces
+ /// all current and future uses of the block argument. This same block
+ /// block argument must not be replaced multiple times. Uses are not replaced
+ /// immediately but in a delayed fashion. Patterns may still see the original
+ /// uses when inspecting IR.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
/// Return the converted value of 'key' with a type defined by the type
@@ -865,6 +958,35 @@ private:
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
+template <typename SelfPattern, typename SourceOp>
+LogicalResult
+ConversionPattern::dispatchTo1To1(const SelfPattern &self, SourceOp op,
+ ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) {
+ FailureOr<SmallVector<Value>> oneToOneOperands =
+ self.getOneToOneAdaptorOperands(operands);
+ if (failed(oneToOneOperands))
+ return rewriter.notifyMatchFailure(op,
+ "pattern '" + self.getDebugName() +
+ "' does not support 1:N conversion");
+ return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
+}
+
+template <typename SelfPattern, typename SourceOp>
+LogicalResult ConversionPattern::dispatchTo1To1(
+ const SelfPattern &self, SourceOp op,
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+ ConversionPatternRewriter &rewriter) {
+ FailureOr<SmallVector<Value>> oneToOneOperands =
+ self.getOneToOneAdaptorOperands(adaptor.getOperands());
+ if (failed(oneToOneOperands))
+ return rewriter.notifyMatchFailure(op,
+ "pattern '" + self.getDebugName() +
+ "' does not support 1:N conversion");
+ return self.matchAndRewrite(
+ op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
+}
+
//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
@@ -1161,6 +1283,16 @@ public:
// ConversionConfig
//===----------------------------------------------------------------------===//
+/// An enum to control folding behavior during dialect conversion.
+enum class DialectConversionFoldingMode {
+ /// Never attempt to fold.
+ Never,
+ /// Only attempt to fold not legal operations before applying patterns.
+ BeforePatterns,
+ /// Only attempt to fold not legal operations after applying patterns.
+ AfterPatterns,
+};
+
/// Dialect conversion configuration.
struct ConversionConfig {
/// An optional callback used to notify about match failure diagnostics during
@@ -1231,18 +1363,29 @@ struct ConversionConfig {
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
/// and cannot be legalized by subsequent foldings / pattern applications.
///
- /// If set to "false", the conversion driver will produce an LLVM fatal error
- /// instead of rolling back IR modifications. Moreover, in case of a failed
- /// conversion, the original IR is not restored. The resulting IR may be a
- /// mix of original and rewritten IR. (Same as a failed greedy pattern
- /// rewrite.)
+ /// Experimental: If set to "false", the conversion driver will produce an
+ /// LLVM fatal error instead of rolling back IR modifications. Moreover, in
+ /// case of a failed conversion, the original IR is not restored. The
+ /// resulting IR may be a mix of original and rewritten IR. (Same as a failed
+ /// greedy pattern rewrite.) Use the cmake build option
+ /// `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON` (ideally together with
+ /// ASAN) to detect invalid pattern API usage.
///
- /// Note: This flag was added in preparation of the One-Shot Dialect
- /// Conversion refactoring, which will remove the ability to roll back IR
- /// modifications from the conversion driver. Use this flag to ensure that
- /// your patterns do not trigger any IR rollbacks. For details, see
+ /// When pattern rollback is disabled, the conversion driver has to maintain
+ /// less internal state. This is more efficient, but not supported by all
+ /// lowering patterns. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;
+
+ /// The folding mode to use during conversion.
+ DialectConversionFoldingMode foldingMode =
+ DialectConversionFoldingMode::BeforePatterns;
+
+ /// If set to "true", the materialization kind ("source" or "target") will be
+ /// attached to "builtin.unrealized_conversion_cast" ops. This flag is useful
+ /// for debugging, to find out what kind of materialization rule may be
+ /// missing.
+ bool attachDebugMaterializationKind = false;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
index 6d62ae3d..7d5c1d5 100644
--- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -27,6 +27,8 @@ namespace mlir {
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
+/// The driver will skip unreachable blocks.
+///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 6cece46..8062b474 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -127,9 +127,12 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
Operation *op = result.getOwner();
// If this is a view, unwrap to the source.
- if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op))
- return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
- visited, output);
+ if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) {
+ if (result == view.getViewDest()) {
+ return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
+ visited, output);
+ }
+ }
// Check to see if we can reason about the control flow of this op.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result,
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 10874fd..9424eff 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
@@ -78,9 +79,17 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
void PredecessorState::print(raw_ostream &os) const {
if (allPredecessorsKnown())
os << "(all) ";
- os << "predecessors:\n";
- for (Operation *op : getKnownPredecessors())
- os << " " << *op << "\n";
+ os << "predecessors:";
+ if (getKnownPredecessors().empty())
+ os << " (none)";
+ else
+ os << "\n";
+ llvm::interleave(
+ getKnownPredecessors(), os,
+ [&](Operation *op) {
+ os << " " << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ },
+ "\n");
}
ChangeResult PredecessorState::join(Operation *predecessor) {
@@ -127,7 +136,7 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
LDBG() << "Initializing DeadCodeAnalysis for top-level op: "
- << top->getName();
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
// Mark the top-level blocks as executable.
for (Region &region : top->getRegions()) {
if (region.empty())
@@ -135,7 +144,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
- LDBG() << "Marked entry block live for region in op: " << top->getName();
+ LDBG() << "Marked entry block live for region in op: "
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
}
// Mark as overdefined the predecessors of symbol callables with potentially
@@ -147,17 +157,19 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
- << top->getName();
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
- LDBG() << "[init] Processing symbol table op: " << symTable->getName();
+ LDBG() << "[init] Processing symbol table op: "
+ << OpWithFlags(symTable, OpPrintingFlags().skipRegions());
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();
bool foundSymbolCallable = false;
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
LDBG() << "[init] Found CallableOpInterface: "
- << callable.getOperation()->getName();
+ << OpWithFlags(callable.getOperation(),
+ OpPrintingFlags().skipRegions());
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
@@ -172,7 +184,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Marked callable as having unknown predecessors: "
- << callable.getOperation()->getName();
+ << OpWithFlags(callable.getOperation(),
+ OpPrintingFlags().skipRegions());
}
foundSymbolCallable = true;
}
@@ -195,7 +208,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Marked nested callable as "
"having unknown predecessors: "
- << callable.getOperation()->getName();
+ << OpWithFlags(callable.getOperation(),
+ OpPrintingFlags().skipRegions());
});
}
@@ -211,13 +225,13 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Found non-call use for symbol, "
"marked as having unknown predecessors: "
- << symbol->getName();
+ << OpWithFlags(symbol, OpPrintingFlags().skipRegions());
}
};
SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
walkFn);
LDBG() << "[init] Finished initializeSymbolCallables for top-level op: "
- << top->getName();
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
}
/// Returns true if the operation is a returning terminator in region
@@ -229,12 +243,13 @@ static bool isRegionOrCallableReturn(Operation *op) {
}
LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
- LDBG() << "[init] Entering initializeRecursively for op: " << op->getName()
- << " at " << op;
+ LDBG() << "[init] Entering initializeRecursively for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
- LDBG() << "[init] Visiting op with control-flow semantics: " << *op;
+ LDBG() << "[init] Visiting op with control-flow semantics: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// When the liveness of the parent block changes, make sure to
// re-invoke the analysis on the op.
if (op->getBlock())
@@ -246,16 +261,17 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
// Recurse on nested operations.
for (Region &region : op->getRegions()) {
- LDBG() << "[init] Recursing into region of op: " << op->getName();
+ LDBG() << "[init] Recursing into region of op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Operation &nestedOp : region.getOps()) {
- LDBG() << "[init] Recursing into nested op: " << nestedOp.getName()
- << " at " << &nestedOp;
+ LDBG() << "[init] Recursing into nested op: "
+ << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
if (failed(initializeRecursively(&nestedOp)))
return failure();
}
}
- LDBG() << "[init] Finished initializeRecursively for op: " << op->getName()
- << " at " << op;
+ LDBG() << "[init] Finished initializeRecursively for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return success();
}
@@ -269,35 +285,40 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
}
void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
- LDBG() << "Marking entry blocks live for op: " << op->getName();
+ LDBG() << "Marking entry blocks live for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
- LDBG() << "Marked entry block live for region in op: " << op->getName();
+ LDBG() << "Marked entry block live for region in op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
}
}
LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
- LDBG() << "Visiting program point: " << point << " " << *point;
+ LDBG() << "Visiting program point: " << *point;
if (point->isBlockStart())
return success();
Operation *op = point->getPrevOp();
- LDBG() << "Visiting operation: " << *op;
+ LDBG() << "Visiting operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the parent block is not executable, there is nothing to do.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->isLive()) {
- LDBG() << "Parent block not live, skipping op: " << *op;
+ LDBG() << "Parent block not live, skipping op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return success();
}
// We have a live call op. Add this as a live predecessor of the callee.
if (auto call = dyn_cast<CallOpInterface>(op)) {
- LDBG() << "Visiting call operation: " << *op;
+ LDBG() << "Visiting call operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
visitCallOperation(call);
}
@@ -305,12 +326,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumRegions()) {
// Check if we can reason about the region control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- LDBG() << "Visiting region branch operation: " << *op;
+ LDBG() << "Visiting region branch operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
visitRegionBranchOperation(branch);
// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
- LDBG() << "Visiting callable operation: " << *op;
+ LDBG() << "Visiting callable operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
@@ -322,19 +345,22 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
// Otherwise, conservatively mark all entry blocks as executable.
} else {
- LDBG() << "Marking all entry blocks live for op: " << *op;
+ LDBG() << "Marking all entry blocks live for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
markEntryBlocksLive(op);
}
}
if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
- LDBG() << "Visiting region terminator: " << *op;
+ LDBG() << "Visiting region terminator: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
- LDBG() << "Visiting callable terminator: " << *op;
+ LDBG() << "Visiting callable terminator: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Visit the exiting terminator of a callable.
visitCallableTerminator(op, callable);
}
@@ -343,12 +369,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumSuccessors()) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
- LDBG() << "Visiting branch operation: " << *op;
+ LDBG() << "Visiting branch operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
visitBranchOperation(branch);
// Otherwise, conservatively mark all successors as exectuable.
} else {
- LDBG() << "Marking all successors live for op: " << *op;
+ LDBG() << "Marking all successors live for op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Block *successor : op->getSuccessors())
markEdgeLive(op->getBlock(), successor);
}
@@ -358,7 +386,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
}
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
- LDBG() << "visitCallOperation: " << call.getOperation()->getName();
+ LDBG() << "visitCallOperation: "
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
// A call to a externally-defined callable has unknown predecessors.
@@ -382,14 +411,14 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
propagateIfChanged(callsites, callsites->join(call));
LDBG() << "Added callsite as predecessor for callable: "
- << callableOp->getName();
+ << OpWithFlags(callableOp, OpPrintingFlags().skipRegions());
} else {
// Mark this call op's predecessors as overdefined.
auto *predecessors =
getOrCreate<PredecessorState>(getProgramPointAfter(call));
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
LDBG() << "Marked call op's predecessors as unknown for: "
- << call.getOperation()->getName();
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
}
}
@@ -421,7 +450,8 @@ DeadCodeAnalysis::getOperandValues(Operation *op) {
}
void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
- LDBG() << "visitBranchOperation: " << branch.getOperation()->getName();
+ LDBG() << "visitBranchOperation: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
// Try to deduce a single successor for the branch.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
@@ -440,7 +470,8 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
void DeadCodeAnalysis::visitRegionBranchOperation(
RegionBranchOpInterface branch) {
- LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName();
+ LDBG() << "visitRegionBranchOperation: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
// Try to deduce which regions are executable.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
@@ -517,14 +548,14 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
LDBG() << "Added callable terminator as predecessor for callsite: "
- << predecessor->getName();
+ << OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
propagateIfChanged(predecessors,
predecessors->setHasUnknownPredecessors());
LDBG() << "Could not resolve callable terminator for callsite: "
- << predecessor->getName();
+ << OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
}
}
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index c7a950d..e79f6a8 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -28,6 +30,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
#include <utility>
@@ -87,7 +90,8 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
return success();
}
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ LDBG() << "Inferring ranges for "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
auto argRanges = llvm::map_to_vector(
operands, [](const IntegerValueRangeLattice *lattice) {
return lattice->getValue();
@@ -99,7 +103,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
return;
assert(llvm::is_contained(op->getResults(), result));
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ LDBG() << "Inferred range " << attrs;
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
IntegerValueRange oldRange = lattice->getValue();
@@ -114,7 +118,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
});
if (isYieldedResult && !oldRange.isUninitialized() &&
!(lattice->getValue() == oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ LDBG() << "Loop variant loop result detected";
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
propagateIfChanged(lattice, changed);
@@ -128,7 +132,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &successor,
ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ LDBG() << "Inferring ranges for "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
@@ -141,7 +146,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
return;
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ LDBG() << "Inferred range " << attrs;
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
IntegerValueRange oldRange = lattice->getValue();
@@ -156,7 +161,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
});
if (isYieldedValue && !oldRange.isUninitialized() &&
!(lattice->getValue() == oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ LDBG() << "Loop variant loop result detected";
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
propagateIfChanged(lattice, changed);
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 509f520..65df355 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
+ << " check on unreachable code now:";
+ // The framework doesn't visit operations in dead blocks, so we need to
+ // explicitly mark them as dead.
+ op->walk([&](Operation *op) {
+ if (op->getNumResults() == 0)
+ return;
+ for (auto result : llvm::enumerate(op->getResults())) {
+ if (getLiveness(result.value()))
+ continue;
+ LDBG() << "Result: " << result.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info (unreachable), mark dead";
+ solver.getOrCreateState<Liveness>(result.value());
+ }
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
+ for (auto blockArg : llvm::enumerate(block.getArguments())) {
+ if (getLiveness(blockArg.value()))
+ continue;
+ LDBG() << "Block argument: " << blockArg.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info, mark dead";
+ solver.getOrCreateState<Liveness>(blockArg.value());
+ }
+ }
+ }
+ });
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e625f62..13a3e14 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -19,12 +19,15 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
+#define DEBUG_TYPE "dataflow"
+
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
@@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+ LDBG() << "Initializing recursively for operation: " << op->getName();
+
// Initialize the analysis by visiting every owner of an SSA value (all
// operations and blocks).
- if (failed(visitOperation(op)))
+ if (failed(visitOperation(op))) {
+ LDBG() << "Failed to visit operation: " << op->getName();
return failure();
+ }
for (Region &region : op->getRegions()) {
+ LDBG() << "Processing region with " << region.getBlocks().size()
+ << " blocks";
for (Block &block : region) {
+ LDBG() << "Processing block with " << block.getNumArguments()
+ << " arguments";
getOrCreate<Executable>(getProgramPointBefore(&block))
->blockContentSubscribe(this);
visitBlock(&block);
- for (Operation &op : block)
- if (failed(initializeRecursively(&op)))
+ for (Operation &op : block) {
+ LDBG() << "Recursively initializing nested operation: " << op.getName();
+ if (failed(initializeRecursively(&op))) {
+ LDBG() << "Failed to initialize nested operation: " << op.getName();
return failure();
+ }
+ }
}
}
+ LDBG() << "Successfully completed recursive initialization for operation: "
+ << op->getName();
return success();
}
@@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+ LDBG() << "Visiting operation: " << op->getName() << " with "
+ << op->getNumOperands() << " operands and " << op->getNumResults()
+ << " results";
+
// If we're in a dead block, bail out.
if (op->getBlock() != nullptr &&
- !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+ ->isLive()) {
+ LDBG() << "Operation is in dead block, bailing out";
return success();
+ }
+ LDBG() << "Creating lattice elements for " << op->getNumOperands()
+ << " operands and " << op->getNumResults() << " results";
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Block arguments of region branch operations flow back into the operands
// of the parent op
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchOpInterface operation";
visitRegionSuccessors(branch, operandLattices);
return success();
}
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+ LDBG() << "Processing BranchOpInterface operation with "
+ << op->getNumSuccessors() << " successors";
+
// Block arguments of successor blocks flow back into our operands.
// We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ LDBG() << "Processing CallOpInterface operation";
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
@@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// of this op itself and the operands of the terminators of the regions of
// this op.
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
visitRegionSuccessorsFromTerminator(terminator, branch);
return success();
@@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
}
if (op->hasTrait<OpTrait::ReturnLike>()) {
+ LDBG() << "Processing ReturnLike operation";
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+ LDBG() << "Callable parent found, visiting callable operation";
return visitCallableOperation(op, callable, operandLattices);
+ }
}
+ LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
return visitOperationImpl(op, operandLattices, resultLattices);
}
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 16f7033..7e1b405 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -45,7 +45,7 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
DATAFLOW_DEBUG({
if (inserted) {
LDBG() << "Creating dependency between " << debugName << " of " << anchor
- << "\nand " << debugName << " on " << dependent;
+ << "\nand " << debugName << " on " << *dependent;
}
});
}
@@ -62,11 +62,12 @@ void ProgramPoint::print(raw_ostream &os) const {
return;
}
if (!isBlockStart()) {
- os << "<after operation>:";
- return getPrevOp()->print(os, OpPrintingFlags().skipRegions());
+ os << "<after operation>:"
+ << OpWithFlags(getPrevOp(), OpPrintingFlags().skipRegions());
+ return;
}
- os << "<before operation>:";
- return getNextOp()->print(os, OpPrintingFlags().skipRegions());
+ os << "<before operation>:"
+ << OpWithFlags(getNextOp(), OpPrintingFlags().skipRegions());
}
//===----------------------------------------------------------------------===//
@@ -78,8 +79,8 @@ void LatticeAnchor::print(raw_ostream &os) const {
os << "<NULL POINT>";
return;
}
- if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
- return LatticeAnchor->print(os);
+ if (auto *latticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
+ return latticeAnchor->print(os);
if (auto value = llvm::dyn_cast<Value>(*this)) {
return value.print(os, OpPrintingFlags().skipRegions());
}
@@ -88,8 +89,8 @@ void LatticeAnchor::print(raw_ostream &os) const {
}
Location LatticeAnchor::getLoc() const {
- if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
- return LatticeAnchor->getLoc();
+ if (auto *latticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
+ return latticeAnchor->getLoc();
if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
@@ -128,7 +129,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
worklist.pop();
DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName
- << "' on: " << point);
+ << "' on: " << *point);
if (failed(analysis->visit(point)))
return failure();
}
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index f4b02b4..30ce1fb 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -60,7 +60,7 @@ private:
AffineExpr localExpr) override {
SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
// Update localVarCst.
- localVarCst.addLocalFloorDiv(dividend, divisor);
+ (void)localVarCst.addLocalFloorDiv(dividend, divisor);
}
LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp
index 4546e49..75d592e 100644
--- a/mlir/lib/Analysis/Presburger/Barvinok.cpp
+++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp
@@ -554,7 +554,7 @@ QuasiPolynomial mlir::presburger::detail::getCoefficientInRationalFunction(
/// t^num / \prod_j (1 - t^dens[j]).
/// v represents the affine functions whose floors are multiplied by the
/// generators, and ds represents the list of generators.
-std::pair<QuasiPolynomial, std::vector<Fraction>>
+static std::pair<QuasiPolynomial, std::vector<Fraction>>
substituteMuInTerm(unsigned numParams, const ParamPoint &v,
const std::vector<Point> &ds, const Point &mu) {
unsigned numDims = mu.size();
@@ -606,8 +606,8 @@ substituteMuInTerm(unsigned numParams, const ParamPoint &v,
/// Here, sign = ± 1,
/// num is a QuasiPolynomial, and
/// each dens[j] is a Fraction.
-void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num,
- std::vector<Fraction> &dens) {
+static void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num,
+ std::vector<Fraction> &dens) {
// We track the number of exponents that are negative in the
// denominator, and convert them to their absolute values.
unsigned numNegExps = 0;
@@ -634,8 +634,8 @@ void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num,
/// Compute the binomial coefficients nCi for 0 ≤ i ≤ r,
/// where n is a QuasiPolynomial.
-std::vector<QuasiPolynomial> getBinomialCoefficients(const QuasiPolynomial &n,
- unsigned r) {
+static std::vector<QuasiPolynomial>
+getBinomialCoefficients(const QuasiPolynomial &n, unsigned r) {
unsigned numParams = n.getNumInputs();
std::vector<QuasiPolynomial> coefficients;
coefficients.reserve(r + 1);
@@ -651,8 +651,8 @@ std::vector<QuasiPolynomial> getBinomialCoefficients(const QuasiPolynomial &n,
/// Compute the binomial coefficients nCi for 0 ≤ i ≤ r,
/// where n is a QuasiPolynomial.
-std::vector<Fraction> getBinomialCoefficients(const Fraction &n,
- const Fraction &r) {
+static std::vector<Fraction> getBinomialCoefficients(const Fraction &n,
+ const Fraction &r) {
std::vector<Fraction> coefficients;
coefficients.reserve((int64_t)floor(r));
coefficients.emplace_back(1);
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 5c4d4d1..0dcdd5b 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1500,12 +1500,13 @@ void IntegerRelation::addBound(BoundType type, ArrayRef<DynamicAPInt> expr,
/// respect to a positive constant 'divisor'. Two constraints are added to the
/// system to capture equivalence with the floordiv.
/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
-void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
- const DynamicAPInt &divisor) {
+/// Returns the column position of the new local variable.
+unsigned IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
+ const DynamicAPInt &divisor) {
assert(dividend.size() == getNumCols() && "incorrect dividend size");
assert(divisor > 0 && "positive divisor expected");
- appendVar(VarKind::Local);
+ unsigned newVar = appendVar(VarKind::Local);
SmallVector<DynamicAPInt, 8> dividendCopy(dividend);
dividendCopy.insert(dividendCopy.end() - 1, DynamicAPInt(0));
@@ -1513,6 +1514,28 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
addInequality(
getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
+ return newVar;
+}
+
+unsigned IntegerRelation::addLocalModulo(ArrayRef<DynamicAPInt> exprs,
+ const DynamicAPInt &modulus) {
+ assert(exprs.size() == getNumCols() && "incorrect exprs size");
+ assert(modulus > 0 && "positive modulus expected");
+
+ /// Add a local variable for q = expr floordiv modulus
+ addLocalFloorDiv(exprs, modulus);
+
+ /// Add a local var to represent the result
+ auto resultIndex = appendVar(VarKind::Local);
+
+ SmallVector<DynamicAPInt, 8> exprsCopy(exprs);
+ /// Insert the two new locals before the constant
+ /// Add locals that correspond to `q` and `result` to compute
+ /// 0 = (expr - modulus * q) - result
+ exprsCopy.insert(exprsCopy.end() - 1,
+ {DynamicAPInt(-modulus), DynamicAPInt(-1)});
+ addEquality(exprsCopy);
+ return resultIndex;
}
int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const {
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 9fc6205..bb60564 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -402,10 +402,10 @@ void Matrix<T>::print(raw_ostream &os) const {
for (unsigned row = 0; row < nRows; ++row)
for (unsigned column = 0; column < nColumns; ++column)
updatePrintMetrics<T>(at(row, column), ptm);
- unsigned MIN_SPACING = 1;
+ unsigned minSpacing = 1;
for (unsigned row = 0; row < nRows; ++row) {
for (unsigned column = 0; column < nColumns; ++column) {
- printWithPrintMetrics<T>(os, at(row, column), MIN_SPACING, ptm);
+ printWithPrintMetrics<T>(os, at(row, column), minSpacing, ptm);
}
os << "\n";
}
@@ -721,7 +721,7 @@ FracMatrix FracMatrix::gramSchmidt() const {
// Otherwise, we swap b_k and b_{k-1} and decrement k.
//
// We repeat this until k = n and return.
-void FracMatrix::LLL(Fraction delta) {
+void FracMatrix::LLL(const Fraction &delta) {
DynamicAPInt nearest;
Fraction mu;
diff --git a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
index 84d885f..4e374d0 100644
--- a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
+++ b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
@@ -112,14 +112,14 @@ QuasiPolynomial QuasiPolynomial::simplify() {
// A term is zero if its coefficient is zero, or
if (coefficients[i] == Fraction(0, 1))
continue;
- bool product_is_zero =
+ bool productIsZero =
// if any of the affine functions in the product
- llvm::any_of(affine[i], [](const SmallVector<Fraction> &affine_ij) {
+ llvm::any_of(affine[i], [](const SmallVector<Fraction> &affineIj) {
// has all its coefficients as zero.
- return llvm::all_of(affine_ij,
+ return llvm::all_of(affineIj,
[](const Fraction &f) { return f == 0; });
});
- if (product_is_zero)
+ if (productIsZero)
continue;
// Now, we know the term is nonzero.
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 08290db..a1cbe29 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -433,7 +433,7 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
normalizeDiv(divCoeffs, divDenom);
domainSimplex.addDivisionVariable(divCoeffs, divDenom);
- domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
+ (void)domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
// Update `this` to account for the additional symbol we just added.
appendSymbol();
@@ -1663,7 +1663,7 @@ public:
/// First pushes a snapshot for the current simplex state to the stack so
/// that this can be rolled back later.
void addEqualityForDirection(ArrayRef<DynamicAPInt> dir) {
- assert(llvm::any_of(dir, [](const DynamicAPInt &x) { return x != 0; }) &&
+ assert(llvm::any_of(dir, [](const DynamicAPInt &X) { return X != 0; }) &&
"Direction passed is the zero vector!");
snapshotStack.emplace_back(simplex.getSnapshot());
simplex.addEquality(getCoeffsForDirection(dir));
@@ -2156,10 +2156,10 @@ void SimplexBase::print(raw_ostream &os) const {
for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row)
for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col)
updatePrintMetrics<DynamicAPInt>(tableau(row, col), ptm);
- unsigned MIN_SPACING = 1;
+ unsigned minSpacing = 1;
for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row) {
for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col) {
- printWithPrintMetrics<DynamicAPInt>(os, tableau(row, col), MIN_SPACING,
+ printWithPrintMetrics<DynamicAPInt>(os, tableau(row, col), minSpacing,
ptm);
}
os << '\n';
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index a2fd149..99546e7 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -101,12 +101,7 @@ bool mlir::sortTopologically(
bool mlir::sortTopologically(
Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
- if (block->empty())
- return true;
- if (block->back().hasTrait<OpTrait::IsTerminator>())
- return sortTopologically(block, block->without_terminator(),
- isOperandReady);
- return sortTopologically(block, *block, isOperandReady);
+ return sortTopologically(block, block->without_terminator(), isOperandReady);
}
bool mlir::computeTopologicalSorting(
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index e5045cf..a21176fff 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -9,8 +9,8 @@
#include "mlir-c/Dialect/GPU.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace nanobind::literals;
@@ -34,7 +34,7 @@ NB_MODULE(_mlirDialectsGPU, m) {
mlirGPUAsyncTokenType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirGPUAsyncTokenTypeGet(ctx));
},
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
@@ -47,8 +47,9 @@ NB_MODULE(_mlirDialectsGPU, m) {
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
.def_classmethod(
"get",
- [](nb::object cls, MlirAttribute target, uint32_t format,
- nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
+ [](const nb::object &cls, MlirAttribute target, uint32_t format,
+ const nb::bytes &object,
+ std::optional<MlirAttribute> mlirObjectProps,
std::optional<MlirAttribute> mlirKernelsAttr) {
MlirStringRef objectStrRef = mlirStringRefCreate(
static_cast<char *>(const_cast<void *>(object.data())),
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index f211e76..ee106c0 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -12,8 +12,8 @@
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
@@ -24,7 +24,7 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
+static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
//===--------------------------------------------------------------------===//
// StructType
@@ -35,8 +35,8 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"get_literal",
- [](nb::object cls, const std::vector<MlirType> &elements, bool packed,
- MlirLocation loc) {
+ [](const nb::object &cls, const std::vector<MlirType> &elements,
+ bool packed, MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
@@ -51,7 +51,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"get_identified",
- [](nb::object cls, const std::string &name, MlirContext context) {
+ [](const nb::object &cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
@@ -59,7 +59,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"get_opaque",
- [](nb::object cls, const std::string &name, MlirContext context) {
+ [](const nb::object &cls, const std::string &name, MlirContext context) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
@@ -79,7 +79,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
llvmStructType.def_classmethod(
"new_identified",
- [](nb::object cls, const std::string &name,
+ [](const nb::object &cls, const std::string &name,
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
return cls(mlirLLVMStructTypeIdentifiedNewGet(
ctx, mlirStringRefCreate(name.data(), name.length()),
@@ -123,7 +123,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
.def_classmethod(
"get",
- [](nb::object cls, std::optional<unsigned> addressSpace,
+ [](const nb::object &cls, std::optional<unsigned> addressSpace,
MlirContext context) {
CollectDiagnosticsToStringScope scope(context);
MlirType type = mlirLLVMPointerTypeGet(
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index a0d6a4b..bb3f519c 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -8,8 +8,8 @@
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
@@ -23,8 +23,8 @@ static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
nvgpuTensorMapDescriptorType.def_classmethod(
"get",
- [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
- int oobFill, int interleave, MlirContext ctx) {
+ [](const nb::object &cls, MlirType tensorMemrefType, int swizzle,
+ int l2promo, int oobFill, int interleave, MlirContext ctx) {
return cls(mlirNVGPUTensorMapDescriptorTypeGet(
ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
},
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index bcc6ff4..2acedbc 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -8,8 +8,8 @@
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
@@ -17,7 +17,7 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectPDLSubmodule(const nanobind::module_ &m) {
+static void populateDialectPDLSubmodule(const nanobind::module_ &m) {
//===-------------------------------------------------------------------===//
// PDLType
//===-------------------------------------------------------------------===//
@@ -32,7 +32,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
attributeType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLAttributeTypeGet(ctx));
},
"Get an instance of AttributeType in given context.", nb::arg("cls"),
@@ -46,7 +46,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
operationType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLOperationTypeGet(ctx));
},
"Get an instance of OperationType in given context.", nb::arg("cls"),
@@ -59,7 +59,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
rangeType.def_classmethod(
"get",
- [](nb::object cls, MlirType elementType) {
+ [](const nb::object &cls, MlirType elementType) {
return cls(mlirPDLRangeTypeGet(elementType));
},
"Gets an instance of RangeType in the same context as the provided "
@@ -77,7 +77,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
typeType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLTypeTypeGet(ctx));
},
"Get an instance of TypeType in given context.", nb::arg("cls"),
@@ -90,7 +90,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) {
auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
valueType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirPDLValueTypeGet(ctx));
},
"Get an instance of TypeType in given context.", nb::arg("cls"),
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 55571cd..a5220fc 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -165,7 +165,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
anyQuantizedType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
@@ -186,7 +186,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
uniformQuantizedType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
return cls(mlirUniformQuantizedTypeGet(flags, storageType,
@@ -221,7 +221,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
uniformQuantizedPerAxisType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, std::vector<double> scales,
std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
@@ -293,7 +293,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
uniformQuantizedSubChannelType.def_classmethod(
"get",
- [](nb::object cls, unsigned flags, MlirType storageType,
+ [](const nb::object &cls, unsigned flags, MlirType storageType,
MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
std::vector<int32_t> quantizedDimensions,
std::vector<int64_t> blockSizes, int64_t storageTypeMin,
@@ -367,7 +367,8 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
quantizedType.get_class());
calibratedQuantizedType.def_classmethod(
"get",
- [](nb::object cls, MlirType expressedType, double min, double max) {
+ [](const nb::object &cls, MlirType expressedType, double min,
+ double max) {
return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
},
"Gets an instance of CalibratedQuantizedType in the same context as the "
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 4e76477..cab4219 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -24,7 +24,7 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectSMTSubmodule(nanobind::module_ &m) {
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
.def_classmethod(
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 97cebcc..9d7dc11 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -12,8 +12,8 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace llvm;
@@ -38,7 +38,8 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
mlirAttributeIsASparseTensorEncodingAttr)
.def_classmethod(
"get",
- [](nb::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+ [](const nb::object &cls,
+ std::vector<MlirSparseTensorLevelType> lvlTypes,
std::optional<MlirAffineMap> dimToLvl,
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
std::optional<MlirAttribute> explicitVal,
@@ -58,7 +59,7 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) {
"Gets a sparse_tensor.encoding from parameters.")
.def_classmethod(
"build_level_type",
- [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt,
+ [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt,
const std::vector<MlirSparseTensorLevelPropertyNondefault>
&properties,
unsigned n, unsigned m) {
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 59a030a..1a62b06 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -11,15 +11,15 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
-void populateDialectTransformSubmodule(const nb::module_ &m) {
+static void populateDialectTransformSubmodule(const nb::module_ &m) {
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -29,7 +29,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformAnyOpTypeGetTypeID);
anyOpType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirTransformAnyOpTypeGet(ctx));
},
"Get an instance of AnyOpType in the given context.", nb::arg("cls"),
@@ -44,7 +44,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformAnyParamTypeGetTypeID);
anyParamType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirTransformAnyParamTypeGet(ctx));
},
"Get an instance of AnyParamType in the given context.", nb::arg("cls"),
@@ -59,7 +59,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformAnyValueTypeGetTypeID);
anyValueType.def_classmethod(
"get",
- [](nb::object cls, MlirContext ctx) {
+ [](const nb::object &cls, MlirContext ctx) {
return cls(mlirTransformAnyValueTypeGet(ctx));
},
"Get an instance of AnyValueType in the given context.", nb::arg("cls"),
@@ -74,7 +74,8 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformOperationTypeGetTypeID);
operationType.def_classmethod(
"get",
- [](nb::object cls, const std::string &operationName, MlirContext ctx) {
+ [](const nb::object &cls, const std::string &operationName,
+ MlirContext ctx) {
MlirStringRef cOperationName =
mlirStringRefCreate(operationName.data(), operationName.size());
return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
@@ -101,7 +102,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) {
mlirTransformParamTypeGetTypeID);
paramType.def_classmethod(
"get",
- [](nb::object cls, MlirType type, MlirContext ctx) {
+ [](const nb::object &cls, MlirType type, MlirContext ctx) {
return cls(mlirTransformParamTypeGet(ctx, type));
},
"Get an instance of ParamType for the given type in the given context.",
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index 81dada3..8bb493e 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/ExecutionEngine.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
@@ -45,7 +45,7 @@ public:
referencedObjects.push_back(obj);
}
- static nb::object createFromCapsule(nb::object capsule) {
+ static nb::object createFromCapsule(const nb::object &capsule) {
MlirExecutionEngine rawPm =
mlirPythonCapsuleToExecutionEngine(capsule.ptr());
if (mlirExecutionEngineIsNull(rawPm))
@@ -113,7 +113,7 @@ NB_MODULE(_mlirExecutionEngine, m) {
.def(
"raw_register_runtime",
[](PyExecutionEngine &executionEngine, const std::string &name,
- nb::object callbackObj) {
+ const nb::object &callbackObj) {
executionEngine.addReferencedObject(callbackObj);
uintptr_t rawSym =
nb::cast<uintptr_t>(nb::getattr(callbackObj, "value"));
@@ -125,6 +125,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
nb::arg("name"), nb::arg("callback"),
"Register `callback` as the runtime symbol `name`.")
.def(
+ "initialize",
+ [](PyExecutionEngine &executionEngine) {
+ mlirExecutionEngineInitialize(executionEngine.get());
+ },
+ "Initialize the ExecutionEngine. Global constructors specified by "
+ "`llvm.mlir.global_ctors` will be run. One common scenario is that "
+ "kernel binary compiled from `gpu.module` gets loaded during "
+ "initialization. Make sure all symbols are resolvable before "
+ "initialization by calling `register_runtime` or including "
+ "shared libraries.")
+ .def(
"dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) {
mlirExecutionEngineDumpToObjectFile(
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 826a34a..71a051c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -10,15 +10,19 @@
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include <optional>
+#include <regex>
#include <string>
+#include <unordered_set>
#include <vector>
#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Regex.h"
namespace mlir {
namespace python {
@@ -114,6 +118,39 @@ public:
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
+ class TracebackLoc {
+ public:
+ bool locTracebacksEnabled();
+
+ void setLocTracebacksEnabled(bool value);
+
+ size_t locTracebackFramesLimit();
+
+ void setLocTracebackFramesLimit(size_t value);
+
+ void registerTracebackFileInclusion(const std::string &file);
+
+ void registerTracebackFileExclusion(const std::string &file);
+
+ bool isUserTracebackFilename(llvm::StringRef file);
+
+ static constexpr size_t kMaxFrames = 512;
+
+ private:
+ nanobind::ft_mutex mutex;
+ bool locTracebackEnabled_ = false;
+ size_t locTracebackFramesLimit_ = 10;
+ std::unordered_set<std::string> userTracebackIncludeFiles;
+ std::unordered_set<std::string> userTracebackExcludeFiles;
+ std::regex userTracebackIncludeRegex;
+ bool rebuildUserTracebackIncludeRegex = false;
+ std::regex userTracebackExcludeRegex;
+ bool rebuildUserTracebackExcludeRegex = false;
+ llvm::StringMap<bool> isUserTracebackFilenameCache;
+ };
+
+ TracebackLoc &getTracebackLoc() { return tracebackLoc; }
+
private:
static PyGlobals *instance;
@@ -134,6 +171,8 @@ private:
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
+
+ TracebackLoc tracebackLoc;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 50f2a4f..a6499c9 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -17,9 +17,9 @@
#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
@@ -64,7 +64,7 @@ static void pyListToVector(const nb::list &list,
}
template <typename PermutationTy>
-static bool isPermutation(std::vector<PermutationTy> permutation) {
+static bool isPermutation(const std::vector<PermutationTy> &permutation) {
llvm::SmallVector<bool, 8> seen(permutation.size(), false);
for (auto val : permutation) {
if (val < permutation.size()) {
@@ -366,7 +366,7 @@ nb::object PyAffineExpr::getCapsule() {
return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
}
-PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
+PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) {
MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
if (mlirAffineExprIsNull(rawAffineExpr))
throw nb::python_error();
@@ -424,7 +424,7 @@ nb::object PyAffineMap::getCapsule() {
return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
}
-PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
+PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) {
MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
if (mlirAffineMapIsNull(rawAffineMap))
throw nb::python_error();
@@ -500,7 +500,7 @@ nb::object PyIntegerSet::getCapsule() {
return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
}
-PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
+PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) {
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
if (mlirIntegerSetIsNull(rawIntegerSet))
throw nb::python_error();
@@ -708,7 +708,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_static("compress_unused_symbols",
- [](nb::list affineMaps, DefaultingPyMlirContext context) {
+ [](const nb::list &affineMaps,
+ DefaultingPyMlirContext context) {
SmallVector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
@@ -734,7 +735,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
+ [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs,
DefaultingPyMlirContext context) {
SmallVector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -869,7 +870,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
.def("__eq__", [](PyIntegerSet &self,
PyIntegerSet &other) { return self == other; })
- .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
+ .def("__eq__",
+ [](PyIntegerSet &self, const nb::object &other) { return false; })
.def("__str__",
[](PyIntegerSet &self) {
PyPrintAccumulator printAccum;
@@ -898,7 +900,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
+ [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs,
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
if (exprs.size() != eqFlags.size())
throw nb::value_error(
@@ -934,8 +936,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
nb::arg("context").none() = nb::none())
.def(
"get_replaced",
- [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
- intptr_t numResultDims, intptr_t numResultSymbols) {
+ [](PyIntegerSet &self, const nb::list &dimExprs,
+ const nb::list &symbolExprs, intptr_t numResultDims,
+ intptr_t numResultSymbols) {
if (static_cast<intptr_t>(dimExprs.size()) !=
mlirIntegerSetGetNumDims(self))
throw nb::value_error(
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index db84ee1..f2eafa7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -505,7 +505,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](nb::list attributes, DefaultingPyMlirContext context) {
+ [](const nb::list &attributes, DefaultingPyMlirContext context) {
SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(nb::len(attributes));
for (auto attribute : attributes) {
@@ -530,7 +530,7 @@ public:
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
- c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
+ c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
std::vector<MlirAttribute> attributes;
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
attributes.reserve(numOldElements + nb::len(extras));
@@ -708,7 +708,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string value, DefaultingPyMlirContext context) {
+ [](const std::string &value, DefaultingPyMlirContext context) {
MlirAttribute attr =
mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
return PyFlatSymbolRefAttribute(context->getRef(), attr);
@@ -736,8 +736,8 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
- DefaultingPyMlirContext context) {
+ [](const std::string &dialectNamespace, const nb_buffer &buffer,
+ PyType &type, DefaultingPyMlirContext context) {
const nb_buffer_info bufferInfo = buffer.request();
intptr_t bufferSize = bufferInfo.size;
MlirAttribute attr = mlirOpaqueAttrGet(
@@ -775,7 +775,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string value, DefaultingPyMlirContext context) {
+ [](const std::string &value, DefaultingPyMlirContext context) {
MlirAttribute attr =
mlirStringAttrGet(context->get(), toMlirStringRef(value));
return PyStringAttribute(context->getRef(), attr);
@@ -784,7 +784,7 @@ public:
"Gets a uniqued string attribute");
c.def_static(
"get",
- [](nb::bytes value, DefaultingPyMlirContext context) {
+ [](const nb::bytes &value, DefaultingPyMlirContext context) {
MlirAttribute attr =
mlirStringAttrGet(context->get(), toMlirStringRef(value));
return PyStringAttribute(context->getRef(), attr);
@@ -793,7 +793,7 @@ public:
"Gets a uniqued string attribute");
c.def_static(
"get_typed",
- [](PyType &type, std::string value) {
+ [](PyType &type, const std::string &value) {
MlirAttribute attr =
mlirStringAttrTypedGet(type, toMlirStringRef(value));
return PyStringAttribute(type.getContext(), attr);
@@ -826,7 +826,7 @@ public:
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseElementsAttribute
- getFromList(nb::list attributes, std::optional<PyType> explicitType,
+ getFromList(const nb::list &attributes, std::optional<PyType> explicitType,
DefaultingPyMlirContext contextWrapper) {
const size_t numAttributes = nb::len(attributes);
if (numAttributes == 0)
@@ -878,8 +878,8 @@ public:
}
static PyDenseElementsAttribute
- getFromBuffer(nb_buffer array, bool signless,
- std::optional<PyType> explicitType,
+ getFromBuffer(const nb_buffer &array, bool signless,
+ const std::optional<PyType> &explicitType,
std::optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
// Request a contiguous view. In exotic cases, this will cause a copy.
@@ -894,8 +894,8 @@ public:
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
MlirContext context = contextWrapper->get();
- MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
- explicitShape, context);
+ MlirAttribute attr = getAttributeFromBuffer(
+ view, signless, explicitType, std::move(explicitShape), context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
@@ -1092,16 +1092,16 @@ private:
"when the type is not a shaped type.");
}
return *bulkLoadElementType;
- } else {
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- return mlirRankedTensorTypeGet(shape.size(), shape.data(),
- *bulkLoadElementType, encodingAttr);
}
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
}
static MlirAttribute getAttributeFromBuffer(
Py_buffer &view, bool signless, std::optional<PyType> explicitType,
- std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
+ const std::optional<std::vector<int64_t>> &explicitShape,
+ MlirContext &context) {
// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes exotics types which do not have a direct
@@ -1125,7 +1125,7 @@ private:
bulkLoadElementType = mlirF16TypeGet(context);
} else if (format == "?") {
// i1
- // The i1 type needs to be bit-packed, so we will handle it seperately
+ // The i1 type needs to be bit-packed, so we will handle it separately
return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
context);
} else if (isSignedIntegerFormat(format)) {
@@ -1205,8 +1205,8 @@ private:
packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
- MlirType bitpackedType =
- getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+ MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1),
+ std::move(explicitShape), view);
assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
// Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
// packedBooleans, hence the MlirAttribute will remain valid even when
@@ -1443,9 +1443,9 @@ public:
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseResourceElementsAttribute
- getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
- std::optional<size_t> alignment, bool isMutable,
- DefaultingPyMlirContext contextWrapper) {
+ getFromBuffer(const nb_buffer &buffer, const std::string &name,
+ const PyType &type, std::optional<size_t> alignment,
+ bool isMutable, DefaultingPyMlirContext contextWrapper) {
if (!mlirTypeIsAShaped(type)) {
throw std::invalid_argument(
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
@@ -1534,7 +1534,7 @@ public:
c.def("__len__", &PyDictAttribute::dunderLen);
c.def_static(
"get",
- [](nb::dict attributes, DefaultingPyMlirContext context) {
+ [](const nb::dict &attributes, DefaultingPyMlirContext context) {
SmallVector<MlirNamedAttribute> mlirNamedAttributes;
mlirNamedAttributes.reserve(attributes.size());
for (std::pair<nb::handle, nb::handle> it : attributes) {
@@ -1618,7 +1618,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType value, DefaultingPyMlirContext context) {
+ [](const PyType &value, DefaultingPyMlirContext context) {
MlirAttribute attr = mlirTypeAttrGet(value.get());
return PyTypeAttribute(context->getRef(), attr);
},
@@ -1663,7 +1663,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](int64_t offset, const std::vector<int64_t> strides,
+ [](int64_t offset, const std::vector<int64_t> &strides,
DefaultingPyMlirContext ctx) {
MlirAttribute attr = mlirStridedLayoutAttrGet(
ctx->get(), offset, strides.size(), strides.data());
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5feed95..2df2a73 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -20,11 +20,8 @@
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
#include <optional>
-#include <system_error>
-#include <utility>
namespace nb = nanobind;
using namespace nb::literals;
@@ -70,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";
+static const char kModuleCAPICreate[] =
+ R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
+Note this returns a new object BUT _clear_mlir_module(module) must be called to
+prevent double-frees (of the underlying mlir::Module).
+)";
+
static const char kOperationCreateDocstring[] =
R"(Creates a new operation.
@@ -199,7 +202,7 @@ operations.
/// Helper for creating an @classmethod.
template <class Func, typename... Args>
-nb::object classmethod(Func f, Args... args) {
+static nb::object classmethod(Func f, Args... args) {
nb::object cf = nb::cpp_function(f, args...);
return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
}
@@ -705,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
return getLiveContexts().size();
}
-size_t PyMlirContext::getLiveOperationCount() {
- nb::ft_lock_guard lock(liveOperationsMutex);
- return liveOperations.size();
-}
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
- std::vector<PyOperation *> liveObjects;
- nb::ft_lock_guard lock(liveOperationsMutex);
- for (auto &entry : liveOperations)
- liveObjects.push_back(entry.second.second);
- return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
-
- LiveOperationMap operations;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- std::swap(operations, liveOperations);
- }
- for (auto &op : operations)
- op.second.second->setInvalid();
- size_t numInvalidated = operations.size();
- return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
- PyOperation *py_op;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- auto it = liveOperations.find(op.ptr);
- if (it == liveOperations.end()) {
- return;
- }
- py_op = it->second.second;
- liveOperations.erase(it);
- }
- py_op->setInvalid();
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
- typedef struct {
- PyOperation &rootOp;
- bool rootSeen;
- } callBackData;
- callBackData data{op.getOperation(), false};
- // Mark all ops below the op that the passmanager will be rooted
- // at (but not op itself - note the preorder) as invalid.
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- callBackData *data = static_cast<callBackData *>(userData);
- if (LLVM_LIKELY(data->rootSeen))
- data->rootOp.getOperation().getContext()->clearOperation(op);
- else
- data->rootSeen = true;
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
- PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
- clearOperationsInside(opRef->getOperation());
-}
-
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
- contextRef->clearOperation(op);
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- &op.getOperation().getContext(), MlirWalkPreOrder);
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
nb::object PyMlirContext::contextEnter(nb::object context) {
return PyThreadContextEntry::pushContext(context);
}
@@ -1154,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
-PyModule::~PyModule() {
- nb::gil_scoped_acquire acquire;
- auto &liveModules = getContext()->liveModules;
- assert(liveModules.count(module.ptr) == 1 &&
- "destroying module not in live map");
- liveModules.erase(module.ptr);
- mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- nb::gil_scoped_acquire acquire;
- auto &liveModules = contextRef->liveModules;
- auto it = liveModules.find(module.ptr);
- if (it == liveModules.end()) {
- // Create.
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is automatic_reference,
- // which does not take ownership (delete will not be called).
- // Just be explicit.
- nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
- unownedModule->handle = pyRef;
- liveModules[module.ptr] =
- std::make_pair(unownedModule->handle, unownedModule);
- return PyModuleRef(unownedModule, std::move(pyRef));
- }
- // Use existing.
- PyModule *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyModuleRef(existing, std::move(pyRef));
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is `automatic_reference`,
+ // which means "does not take ownership, does not call delete/dtor".
+ // We use `take_ownership`, which means "Python will call the C++ destructor
+ // and delete operator when the Python wrapper is garbage collected", because
+ // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
+ // etc).
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ return PyModuleRef(unownedModule, std::move(pyRef));
}
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1210,15 +1120,11 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
-
- // Otherwise, invalidate the operation and remove it from live map when it is
- // attached.
- if (isAttached()) {
- getContext()->clearOperation(*this);
- } else {
- // And destroy it when it is detached, i.e. owned by Python, in which case
- // all nested operations must be invalidated at removed from the live map as
- // well.
+ // Otherwise, invalidate the operation when it is attached.
+ if (isAttached())
+ setInvalid();
+ else {
+ // And destroy it when it is detached, i.e. owned by Python.
erase();
}
}
@@ -1255,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
- auto &liveOperations = contextRef->liveOperations;
- auto it = liveOperations.find(operation.ptr);
- if (it == liveOperations.end()) {
- // Create.
- PyOperationRef result = createInstance(std::move(contextRef), operation,
- std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(result.getObject(), result.get());
- return result;
- }
- // Use existing.
- PyOperation *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyOperationRef(existing, std::move(pyRef));
+ return createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
- auto &liveOperations = contextRef->liveOperations;
- assert(liveOperations.count(operation.ptr) == 0 &&
- "cannot create detached operation that already exists");
- (void)liveOperations;
PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(created.getObject(), created.get());
created->attached = false;
return created;
}
@@ -1523,7 +1409,7 @@ nb::object PyOperation::create(std::string_view name,
llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- int regions, DefaultingPyLocation location,
+ int regions, PyLocation &location,
const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1513,7 @@ nb::object PyOperation::create(std::string_view name,
if (!operation.ptr)
throw nb::value_error("Operation creation failed");
PyOperationRef created =
- PyOperation::createDetached(location->getContext(), operation);
+ PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);
return created.getObject();
@@ -1655,7 +1541,7 @@ nb::object PyOperation::createOpView() {
void PyOperation::erase() {
checkValid();
- getContext()->clearOperationAndInside(*this);
+ setInvalid();
mlirOperationDestroy(operation);
}
@@ -1937,9 +1823,9 @@ nb::object PyOpView::buildGeneric(
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, PyLocation &location,
const nb::object &maybeIp) {
- PyMlirContextRef context = location->getContext();
+ PyMlirContextRef context = location.getContext();
// Class level operation construction metadata.
// Operand and result segment specs are either none, which does no
@@ -2108,7 +1994,7 @@ nb::object PyOpView::buildGeneric(
// Delegate to create.
return PyOperation::create(name,
/*results=*/std::move(resultTypes),
- /*operands=*/std::move(operands),
+ /*operands=*/operands,
/*attributes=*/std::move(attributes),
/*successors=*/std::move(successors),
/*regions=*/*regions, location, maybeIp,
@@ -2789,6 +2675,156 @@ private:
PyOperationRef operation;
};
+// see
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using directly NULL
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
+ (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+ Py_XINCREF(obj);
+ return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+ Py_INCREF(obj);
+ return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame()
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+ assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+ return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack()
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+ assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+ return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode()
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+ assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+ assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+ return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+MlirLocation tracebackToLocation(MlirContext ctx) {
+ size_t framesLimit =
+ PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+ // Use a thread_local here to avoid requiring a large amount of space.
+ thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+ frames;
+ size_t count = 0;
+
+ nb::gil_scoped_acquire acquire;
+ PyThreadState *tstate = PyThreadState_GET();
+ PyFrameObject *next;
+ PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
+ // In the increment expression:
+ // 1. get the next prev frame;
+ // 2. decrement the ref count on the current frame (in order that it can get
+ // gc'd, along with any objects in its closure and etc);
+ // 3. set current = next.
+ for (; pyFrame != nullptr && count < framesLimit;
+ next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
+ PyCodeObject *code = PyFrame_GetCode(pyFrame);
+ auto fileNameStr =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+ llvm::StringRef fileName(fileNameStr);
+ if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+ continue;
+
+ // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+ std::string name =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+ llvm::StringRef funcName(name);
+ int startLine = PyFrame_GetLineNumber(pyFrame);
+ MlirLocation loc =
+ mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+ std::string name =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+ llvm::StringRef funcName(name);
+ int startLine, startCol, endLine, endCol;
+ int lasti = PyFrame_GetLasti(pyFrame);
+ if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+ &endCol)) {
+ throw nb::python_error();
+ }
+ MlirLocation loc = mlirLocationFileLineColRangeGet(
+ ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+
+ frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+ ++count;
+ }
+ // When the loop breaks (after the last iter), current frame (if non-null)
+ // is leaked without this.
+ Py_XDECREF(pyFrame);
+
+ if (count == 0)
+ return mlirLocationUnknownGet(ctx);
+
+ MlirLocation callee = frames[0];
+ assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+ if (count == 1)
+ return callee;
+
+ MlirLocation caller = frames[count - 1];
+ assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+ for (int i = count - 2; i >= 1; i--)
+ caller = mlirLocationCallSiteGet(frames[i], caller);
+
+ return mlirLocationCallSiteGet(callee, caller);
+}
+
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+ if (location.has_value())
+ return location.value();
+ if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
+ return DefaultingPyLocation::resolve();
+
+ PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
+ MlirLocation mlirLoc = tracebackToLocation(ctx.get());
+ PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
+ return {ref, mlirLoc};
+}
+
} // namespace
//------------------------------------------------------------------------------
@@ -2876,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
- .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
- .def("_get_live_operation_objects",
- &PyMlirContext::getLiveOperationObjects)
- .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
- .def("_clear_live_operations_inside",
- nb::overload_cast<MlirOperation>(
- &PyMlirContext::clearOperationsInside))
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3052,10 +3080,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
.def_prop_ro_static(
"current",
- [](nb::object & /*class*/) {
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
auto *loc = PyThreadContextEntry::getDefaultLocation();
if (!loc)
- throw nb::value_error("No current Location");
+ return std::nullopt;
return loc;
},
"Gets the Location bound to the current thread or raises ValueError")
@@ -3201,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ kModuleCAPICreate)
+ .def("_clear_mlir_module", &PyModule::clearMlirModule)
.def_static(
"parse",
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3240,8 +3270,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
- [](DefaultingPyLocation loc) {
- MlirModule module = mlirModuleCreateEmpty(loc);
+ [](const std::optional<PyLocation> &loc) {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3280,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
- kOperationStrDunderDocstring);
+ kOperationStrDunderDocstring)
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a);
//----------------------------------------------------------------------------
// Mapping of Operation.
@@ -3292,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
- return &self.getOperation() == &other.getOperation();
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
})
.def("__eq__",
[](PyOperationBase &self, nb::object other) { return false; })
@@ -3442,6 +3480,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
nb::arg("walk_order") = MlirWalkPostOrder);
@@ -3454,8 +3500,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const nb::object &maybeIp,
- bool inferType) {
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp, bool inferType) {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
if (operands) {
@@ -3467,8 +3513,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
}
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, location, maybeIp,
+ successors, regions, pyLoc, maybeIp,
inferType);
},
nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3498,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
- "Returns the list of Operation successors.");
+ "Returns the list of Operation successors.")
+ .def("_set_invalid", &PyOperation::setInvalid,
+ "Invalidate the operation.");
auto opViewClass =
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
@@ -3512,12 +3561,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, location, maybeIp));
+ attributes, successors, regions, pyLoc, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3540,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
- "Returns the list of Operation successors.");
+ "Returns the list of Operation successors.")
+ .def(
+ "_set_invalid",
+ [](PyOpView &self) { self.getOperation().setInvalid(); },
+ "Invalidate the operation.");
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
@@ -3551,17 +3606,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, std::optional<PyLocation> location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
- regions, location, maybeIp);
+ regions, pyLoc, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index e600f1b..0de2f17 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -13,9 +13,9 @@
#include "Globals.h"
#include "NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
namespace nb = nanobind;
using namespace mlir;
@@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Not found and loading did not yield a registration.
return std::nullopt;
}
+
+bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
+ nanobind::ft_lock_guard lock(mutex);
+ return locTracebackEnabled_;
+}
+
+void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
+ nanobind::ft_lock_guard lock(mutex);
+ locTracebackEnabled_ = value;
+}
+
+size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
+ nanobind::ft_lock_guard lock(mutex);
+ return locTracebackFramesLimit_;
+}
+
+void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
+ nanobind::ft_lock_guard lock(mutex);
+ locTracebackFramesLimit_ = std::min(value, kMaxFrames);
+}
+
+void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
+ const std::string &file) {
+ nanobind::ft_lock_guard lock(mutex);
+ auto reg = "^" + llvm::Regex::escape(file);
+ if (userTracebackIncludeFiles.insert(reg).second)
+ rebuildUserTracebackIncludeRegex = true;
+ if (userTracebackExcludeFiles.count(reg)) {
+ if (userTracebackExcludeFiles.erase(reg))
+ rebuildUserTracebackExcludeRegex = true;
+ }
+}
+
+void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
+ const std::string &file) {
+ nanobind::ft_lock_guard lock(mutex);
+ auto reg = "^" + llvm::Regex::escape(file);
+ if (userTracebackExcludeFiles.insert(reg).second)
+ rebuildUserTracebackExcludeRegex = true;
+ if (userTracebackIncludeFiles.count(reg)) {
+ if (userTracebackIncludeFiles.erase(reg))
+ rebuildUserTracebackIncludeRegex = true;
+ }
+}
+
+bool PyGlobals::TracebackLoc::isUserTracebackFilename(
+ const llvm::StringRef file) {
+ nanobind::ft_lock_guard lock(mutex);
+ if (rebuildUserTracebackIncludeRegex) {
+ userTracebackIncludeRegex.assign(
+ llvm::join(userTracebackIncludeFiles, "|"));
+ rebuildUserTracebackIncludeRegex = false;
+ isUserTracebackFilenameCache.clear();
+ }
+ if (rebuildUserTracebackExcludeRegex) {
+ userTracebackExcludeRegex.assign(
+ llvm::join(userTracebackExcludeFiles, "|"));
+ rebuildUserTracebackExcludeRegex = false;
+ isUserTracebackFilenameCache.clear();
+ }
+ if (!isUserTracebackFilenameCache.contains(file)) {
+ std::string fileStr = file.str();
+ bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
+ bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
+ isUserTracebackFilenameCache[file] = include || !exclude;
+ }
+ return isUserTracebackFilenameCache[file];
+}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9c22dea..0cc0459 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -192,16 +192,6 @@ public:
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;
- /// For the case of a python __init__ (nanobind::init) method, pybind11 is
- /// quite strict about needing to return a pointer that is not yet associated
- /// to an nanobind::object. Since the forContext() method acts like a pool,
- /// possibly returning a recycled context, it does not satisfy this need. The
- /// usual way in python to accomplish such a thing is to override __new__, but
- /// that is also not supported by pybind11. Instead, we use this entry
- /// point which always constructs a fresh context (which cannot alias an
- /// existing one because it is fresh).
- static PyMlirContext *createNewContextForInit();
-
/// Returns a context reference for the singleton PyMlirContext wrapper for
/// the given context.
static PyMlirContextRef forContext(MlirContext context);
@@ -228,40 +218,6 @@ public:
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
- /// Get a list of Python objects which are still in the live context map.
- std::vector<PyOperation *> getLiveOperationObjects();
-
- /// Gets the count of live operations associated with this context.
- /// Used for testing.
- size_t getLiveOperationCount();
-
- /// Clears the live operations map, returning the number of entries which were
- /// invalidated. To be used as a safety mechanism so that API end-users can't
- /// corrupt by holding references they shouldn't have accessed in the first
- /// place.
- size_t clearLiveOperations();
-
- /// Removes an operation from the live operations map and sets it invalid.
- /// This is useful for when some non-bindings code destroys the operation and
- /// the bindings need to made aware. For example, in the case when pass
- /// manager is run.
- ///
- /// Note that this does *NOT* clear the nested operations.
- void clearOperation(MlirOperation op);
-
- /// Clears all operations nested inside the given op using
- /// `clearOperation(MlirOperation)`.
- void clearOperationsInside(PyOperationBase &op);
- void clearOperationsInside(MlirOperation op);
-
- /// Clears the operaiton _and_ all operations inside using
- /// `clearOperation(MlirOperation)`.
- void clearOperationAndInside(PyOperationBase &op);
-
- /// Gets the count of live modules associated with this context.
- /// Used for testing.
- size_t getLiveModuleCount();
-
/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
@@ -288,25 +244,6 @@ private:
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
- // Interns all live modules associated with this context. Modules tracked
- // in this map are valid. When a module is invalidated, it is removed
- // from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveModuleMap =
- llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
- LiveModuleMap liveModules;
-
- // Interns all live operations associated with this context. Operations
- // tracked in this map are valid. When an operation is invalidated, it is
- // removed from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveOperationMap =
- llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
- nanobind::ft_mutex liveOperationsMutex;
-
- // Guarded by liveOperationsMutex in free-threading mode.
- LiveOperationMap liveOperations;
-
bool emitErrorDiagnostics = false;
MlirContext context;
@@ -558,8 +495,8 @@ class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
class PyModule : public BaseContextObject {
public:
- /// Returns a PyModule reference for the given MlirModule. This may return
- /// a pre-existing or new object.
+ /// Returns a PyModule reference for the given MlirModule. This always returns
+ /// a new object.
static PyModuleRef forModule(MlirModule module);
PyModule(PyModule &) = delete;
PyModule(PyMlirContext &&) = delete;
@@ -580,11 +517,12 @@ public:
nanobind::object getCapsule();
/// Creates a PyModule from the MlirModule wrapped by a capsule.
- /// Note that PyModule instances are uniqued, so the returned object
- /// may be a pre-existing object. Ownership of the underlying MlirModule
- /// is taken by calling this function.
+ /// Note this returns a new object BUT clearMlirModule() must be called to
+ /// prevent double-frees (of the underlying mlir::Module).
static nanobind::object createFromCapsule(nanobind::object capsule);
+ void clearMlirModule() { module = {nullptr}; }
+
private:
PyModule(PyMlirContextRef contextRef, MlirModule module);
MlirModule module;
@@ -722,8 +660,7 @@ public:
llvm::ArrayRef<MlirValue> operands,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const nanobind::object &ip,
- bool inferType);
+ PyLocation &location, const nanobind::object &ip, bool inferType);
/// Creates an OpView suitable for this operation.
nanobind::object createOpView();
@@ -781,7 +718,7 @@ public:
nanobind::list operandList,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, PyLocation &location,
const nanobind::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
@@ -1227,7 +1164,7 @@ public:
/// Note that PyAffineExpr instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
/// is taken by calling this function.
- static PyAffineExpr createFromCapsule(nanobind::object capsule);
+ static PyAffineExpr createFromCapsule(const nanobind::object &capsule);
PyAffineExpr add(const PyAffineExpr &other) const;
PyAffineExpr mul(const PyAffineExpr &other) const;
@@ -1254,7 +1191,7 @@ public:
/// Note that PyAffineMap instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineMap
/// is taken by calling this function.
- static PyAffineMap createFromCapsule(nanobind::object capsule);
+ static PyAffineMap createFromCapsule(const nanobind::object &capsule);
private:
MlirAffineMap affineMap;
@@ -1274,7 +1211,7 @@ public:
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
/// Note that PyIntegerSet instances may be uniqued, so the returned object
/// may be a pre-existing object. Integer sets are owned by the context.
- static PyIntegerSet createFromCapsule(nanobind::object capsule);
+ static PyIntegerSet createFromCapsule(const nanobind::object &capsule);
private:
MlirIntegerSet integerSet;
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index b11e3f7..a9b1259 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -963,7 +963,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::string dialectNamespace, std::string typeData,
+ [](const std::string &dialectNamespace, const std::string &typeData,
DefaultingPyMlirContext context) {
MlirType type = mlirOpaqueTypeGet(context->get(),
toMlirStringRef(dialectNamespace),
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431..278847e 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-
#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
@@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) {
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, nb::kw_only(),
"replace"_a = false,
- "Testing hook for directly registering an operation");
+ "Testing hook for directly registering an operation")
+ .def("loc_tracebacks_enabled",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebacksEnabled();
+ })
+ .def("set_loc_tracebacks_enabled",
+ [](PyGlobals &self, bool enabled) {
+ self.getTracebackLoc().setLocTracebacksEnabled(enabled);
+ })
+ .def("set_loc_tracebacks_frame_limit",
+ [](PyGlobals &self, int n) {
+ self.getTracebackLoc().setLocTracebackFramesLimit(n);
+ })
+ .def("register_traceback_file_inclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileInclusion(filename);
+ })
+ .def("register_traceback_file_exclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileExclusion(filename);
+ });
// Aside from making the globals accessible to python, having python manage
// it is necessary to make sure it is destroyed (and releases its python
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 20017e2..88e28dc 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -39,7 +39,7 @@ public:
return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
}
- static nb::object createFromCapsule(nb::object capsule) {
+ static nb::object createFromCapsule(const nb::object &capsule) {
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
if (mlirPassManagerIsNull(rawPm))
throw nb::python_error();
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"ValueError if the pipeline can't be parsed.")
.def(
"run",
- [](PyPassManager &passManager, PyOperationBase &op,
- bool invalidateOps) {
- if (invalidateOps) {
- op.getOperation().getContext()->clearOperationsInside(op);
- }
+ [](PyPassManager &passManager, PyOperationBase &op) {
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
- "operation"_a, "invalidate_ops"_a = true,
+ "operation"_a,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp
index 3ba42be..3edcb09 100644
--- a/mlir/lib/Bindings/Python/RegisterEverything.cpp
+++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/RegisterEverything.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
NB_MODULE(_mlirRegisterEverything, m) {
m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration";
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f9b0fed..920bca8 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
// root. This is awkward, but we don't have access to PyMlirContext
// object here otherwise.
nb::object obj = nb::cast(payloadRoot);
- obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot, transformRoot, transformModule, options.options);
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 6ebeac5..eacb936 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -150,8 +151,7 @@ public:
/// Backpatch a byte in the result buffer at the given offset.
void patchByte(uint64_t offset, uint8_t value, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs() << "patchByte(" << offset << ',' << uint64_t(value)
- << ")\t" << desc << '\n');
+ LDBG() << "patchByte(" << offset << ',' << uint64_t(value) << ")\t" << desc;
assert(offset < size() && offset >= prevResultSize &&
"cannot patch previously emitted data");
currentResult[offset - prevResultSize] = value;
@@ -160,8 +160,7 @@ public:
/// Emit the provided blob of data, which is owned by the caller and is
/// guaranteed to not die before the end of the bytecode process.
void emitOwnedBlob(ArrayRef<uint8_t> data, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs()
- << "emitOwnedBlob(" << data.size() << "b)\t" << desc << '\n');
+ LDBG() << "emitOwnedBlob(" << data.size() << "b)\t" << desc;
// Push the current buffer before adding the provided data.
appendResult(std::move(currentResult));
appendOwnedResult(data);
@@ -209,15 +208,13 @@ public:
/// Emit a single byte.
template <typename T>
void emitByte(T byte, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs()
- << "emitByte(" << uint64_t(byte) << ")\t" << desc << '\n');
+ LDBG() << "emitByte(" << uint64_t(byte) << ")\t" << desc;
currentResult.push_back(static_cast<uint8_t>(byte));
}
/// Emit a range of bytes.
void emitBytes(ArrayRef<uint8_t> bytes, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs()
- << "emitBytes(" << bytes.size() << "b)\t" << desc << '\n');
+ LDBG() << "emitBytes(" << bytes.size() << "b)\t" << desc;
llvm::append_range(currentResult, bytes);
}
@@ -229,7 +226,7 @@ public:
/// additional bytes, provide the value of the integer encoded in
/// little-endian order.
void emitVarInt(uint64_t value, StringLiteral desc) {
- LLVM_DEBUG(llvm::dbgs() << "emitVarInt(" << value << ")\t" << desc << '\n');
+ LDBG() << "emitVarInt(" << value << ")\t" << desc;
// In the most common case, the value can be represented in a single byte.
// Given how hot this case is, explicitly handle that here.
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 306cebd..2dbb993 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
return wrap(jitOrError->release());
}
+extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
+ unwrap(jit)->initialize();
+}
+
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
delete (unwrap(jit));
}
@@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
void *sym) {
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
- symbolMap[interner(unwrap(name))] =
- { llvm::orc::ExecutorAddr::fromPtr(sym),
- llvm::JITSymbolFlags::Exported };
+ symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
+ llvm::JITSymbolFlags::Exported};
return symbolMap;
});
}
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 9d8554a..f5f4ed3 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -465,10 +465,6 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
-MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
- return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
-}
-
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8491553..c7069f0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
}
+bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) {
+ return unwrap(lhs) == unwrap(rhs);
+}
+
//===----------------------------------------------------------------------===//
// Operation state API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 191b5ab6..91ed05f 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Query)
add_subdirectory(Reducer)
+add_subdirectory(Remark)
add_subdirectory(Rewrite)
add_subdirectory(Support)
add_subdirectory(TableGen)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 64720bf..203790e 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -1876,6 +1877,54 @@ struct AMDGPUSwizzleBitModeLowering
}
};
+struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx950)
+ return op->emitOpError("permlane_swap is only supported on gfx950+");
+
+ Location loc = op.getLoc();
+ Type i32 = rewriter.getI32Type();
+ Value src = adaptor.getSrc();
+ unsigned rowLength = op.getRowLength();
+ bool fi = op.getFetchInactive();
+ bool boundctrl = op.getBoundCtrl();
+
+ SmallVector<Value> decomposed =
+ LLVM::decomposeValue(rewriter, loc, src, i32);
+
+ SmallVector<Value> permuted;
+ for (Value v : decomposed) {
+ Value res;
+ Type i32pair = LLVM::LLVMStructType::getLiteral(
+ rewriter.getContext(), {v.getType(), v.getType()});
+
+ if (rowLength == 16)
+ res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
+ boundctrl);
+ else if (rowLength == 32)
+ res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
+ boundctrl);
+ else
+ llvm_unreachable("unsupported row length");
+
+ Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ permuted.emplace_back(vdstNew);
+ }
+
+ Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1944,6 +1993,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering>(converter, chipset);
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 515fe5c..b68933d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -610,16 +610,19 @@ public:
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
- emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
- rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
- Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
+ emitc::ExpressionOp ternary =
+ emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
+ ValueRange({lhs, rhs, excessCheck, poison}),
+ /*do_not_inline=*/false);
+ Block &bodyBlock = ternary.createBody();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
- EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
- Value resultOrPoison =
- emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
- excessCheck, arithmeticResult, poison);
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
+ bodyBlock.getArgument(0), bodyBlock.getArgument(1));
+ Value resultOrPoison = emitc::ConditionalOp::create(
+ rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
+ arithmeticResult, bodyBlock.getArgument(3));
emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c..cb0c829 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -480,6 +490,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
}
//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// than can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+ arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // In case of a 1:1 conversion, the 1:1 pattern will match.
+ if (llvm::hasSingleElement(adaptor.getTrueValue()))
+ return rewriter.notifyMatchFailure(
+ op, "not a 1:N conversion, 1:1 pattern will match");
+ if (!op.getCondition().getType().isInteger(1))
+ return rewriter.notifyMatchFailure(op,
+ "non-i1 conditions are not supported");
+ SmallVector<Value> results;
+ for (auto [trueValue, falseValue] :
+ llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+ results.push_back(arith::SelectOp::create(
+ rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
+ SelectOpOneToNLowering,
ShLIOpLowering,
ShRSIOpLowering,
ShRUIOpLowering,
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e28d5122..c69ede9 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -333,7 +333,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
tileLoadOp.getBase(),
memrefIndices, maskOp1D,
- /*passthru=*/pad1DOp);
+ /*passthrough=*/pad1DOp);
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
auto insertSlice = arm_sme::InsertTileSliceOp::create(
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb82..71986f8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -50,6 +50,7 @@ add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
+add_subdirectory(PtrToLLVM)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
@@ -68,6 +69,7 @@ add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
@@ -75,3 +77,4 @@ add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
add_subdirectory(VectorToXeGPU)
add_subdirectory(XeVMToLLVM)
+add_subdirectory(XeGPUToXeVM)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 35ad99c..7a3a7fd 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,22 +56,30 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
private:
std::string funcName;
};
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+ using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+ Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+ Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+ rewriter.replaceOp(op, exp);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
- patterns.getContext(), "__ocml_carg_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
- patterns.getContext(), "__ocml_carg_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
- patterns.getContext(), "__ocml_conj_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
- patterns.getContext(), "__ocml_conj_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
patterns.getContext(), "__ocml_ccos_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
@@ -84,10 +92,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_clog_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
patterns.getContext(), "__ocml_clog_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
- patterns.getContext(), "__ocml_cpow_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
- patterns.getContext(), "__ocml_cpow_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
patterns.getContext(), "__ocml_csin_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
@@ -122,10 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
- complex::CosOp, complex::ExpOp, complex::LogOp,
- complex::PowOp, complex::SinOp, complex::SqrtOp,
- complex::TanOp, complex::TanhOp>();
+ target.addLegalOp<complex::MulOp>();
+ target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
+ complex::LogOp, complex::PowOp, complex::SinOp,
+ complex::SqrtOp, complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369..798d8b0 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const ValueRange &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(ValueRange(flattenedAdaptor)));
if (failed(convertedBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedAdaptor, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(attrs);
@@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(ValueRange(flattenedAdaptorTrue)));
if (failed(convertedTrueBlock))
return failure();
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(ValueRange(flattenedAdaptorFalse)));
if (failed(convertedFalseBlock))
return failure();
- DictionaryAttr attrs = op->getAttrDictionary();
+ DictionaryAttr attrs = op->getDiscardableAttrDictionary();
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
- adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
*convertedTrueBlock, *convertedFalseBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
- newOp->setAttrs(attrs);
+ newOp->setDiscardableAttrs(attrs);
return success();
}
};
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4..764ad2e 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/DebugLog.h"
#include <memory>
#define DEBUG_TYPE "convert-to-llvm"
@@ -31,7 +32,8 @@ namespace {
class ConvertToLLVMPassInterface {
public:
ConvertToLLVMPassInterface(MLIRContext *context,
- ArrayRef<std::string> filterDialects);
+ ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback = true);
virtual ~ConvertToLLVMPassInterface() = default;
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +62,9 @@ protected:
MLIRContext *context;
/// List of dialects names to use as filters.
ArrayRef<std::string> filterDialects;
+ /// An experimental flag to disallow pattern rollback. This is more efficient
+ /// but not supported by all lowering patterns.
+ bool allowPatternRollback;
};
/// This DialectExtension can be attached to the context, which will invoke the
@@ -75,13 +80,13 @@ public:
void apply(MLIRContext *context,
MutableArrayRef<Dialect *> dialects) const final {
- LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
+ LDBG() << "Convert to LLVM extension load";
for (Dialect *dialect : dialects) {
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
- LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
- << dialect->getNamespace() << "\n");
+ LDBG() << "Convert to LLVM found dialect interface for "
+ << dialect->getNamespace();
iface->loadDependentDialects(context);
}
}
@@ -128,7 +133,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
/// Apply the conversion driver.
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
- if (failed(applyPartialConversion(op, *target, *patterns)))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, *target, *patterns, config)))
return failure();
return success();
}
@@ -179,7 +186,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
patterns);
// Apply the conversion.
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
return failure();
return success();
}
@@ -206,9 +215,11 @@ public:
std::shared_ptr<ConvertToLLVMPassInterface> impl;
// Choose the pass implementation.
if (useDynamic)
- impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
else
- impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
if (failed(impl->initialize()))
return failure();
this->impl = impl;
@@ -228,8 +239,10 @@ public:
//===----------------------------------------------------------------------===//
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
- MLIRContext *context, ArrayRef<std::string> filterDialects)
- : context(context), filterDialects(filterDialects) {}
+ MLIRContext *context, ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback)
+ : context(context), filterDialects(filterDialects),
+ allowPatternRollback(allowPatternRollback) {}
void ConvertToLLVMPassInterface::getDependentDialects(
DialectRegistry &registry) {
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 67bb1c1..42c76ed 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
- LogicalResult matchAndRewriteImpl(CallOpType callOp,
- typename CallOpType::Adaptor adaptor,
+ LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
+ SmallVector<SmallVector<Type>> groupedResultTypes;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
+ int64_t numConvertedTypes = 0;
if (numResults != 0) {
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
- resultTypes, useBarePtrCallConv)))
+ resultTypes, useBarePtrCallConv, &groupedResultTypes,
+ &numConvertedTypes)))
return failure();
}
@@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
- SmallVector<Value, 4> results;
- if (numResults < 2) {
- // If < 2 results, packing did not do anything and we can just return.
- results.append(newOp.result_begin(), newOp.result_end());
- } else {
- // Otherwise, it had been converted to an operation producing a structure.
- // Extract individual results from the structure and return them as list.
- results.reserve(numResults);
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(LLVM::ExtractValueOp::create(
- rewriter, callOp.getLoc(), newOp->getResult(0), i));
+ // Helper function that extracts an individual result from the return value
+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+ // 2 or more results, the results are packed into a structure.
+ //
+ // The new call op may have more than 2 results because:
+ // a. The original call op has more than 2 results.
+ // b. An original op result type-converted to more than 1 result.
+ auto getUnpackedResult = [&](unsigned i) -> Value {
+ assert(numConvertedTypes > 0 && "convert op has no results");
+ if (numConvertedTypes == 1) {
+ assert(i == 0 && "out of bounds: converted op has only one result");
+ return newOp->getResult(0);
}
+ // Results have been converted to a structure. Extract individual results
+ // from the structure.
+ return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+ newOp->getResult(0), i);
+ };
+
+ // Group the results into a vector of vectors, such that it is clear which
+ // original op result is replaced with which range of values. (In case of a
+ // 1:N conversion, there can be multiple replacements for a single result.)
+ SmallVector<SmallVector<Value>> results;
+ results.reserve(numResults);
+ unsigned counter = 0;
+ for (unsigned i = 0; i < numResults; ++i) {
+ SmallVector<Value> &group = results.emplace_back();
+ for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
+ group.push_back(getUnpackedResult(counter++));
}
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, promote memref results to
- // descriptors.
- assert(results.size() == resultTypes.size() &&
- "The number of arguments and types doesn't match");
- this->getTypeConverter()->promoteBarePtrsToDescriptors(
- rewriter, callOp.getLoc(), resultTypes, results);
- } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
- resultTypes, results,
- /*toDynamic=*/false))) {
- return failure();
+ // Special handling for MemRef types.
+ for (unsigned i = 0; i < numResults; ++i) {
+ Type origType = resultTypes[i];
+ auto memrefType = dyn_cast<MemRefType>(origType);
+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+ if (useBarePtrCallConv && memrefType) {
+ // For the bare-ptr calling convention, promote memref results to
+ // descriptors.
+ assert(results[i].size() == 1 && "expected one converted result");
+ results[i].front() = MemRefDescriptor::fromStaticShape(
+ rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+ results[i].front());
+ }
+ if (unrankedMemrefType) {
+ assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+ "bare-ptr calling convention");
+ assert(results[i].size() == 1 && "expected one converted result");
+ Value desc = this->copyUnrankedDescriptor(
+ rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+ /*toDynamic=*/false);
+ if (!desc)
+ return failure();
+ results[i].front() = desc;
+ }
}
- rewriter.replaceOp(callOp, results);
+ rewriter.replaceOpWithMultiple(callOp, results);
return success();
}
};
@@ -606,7 +638,7 @@ public:
symbolTables(symbolTables) {}
LogicalResult
- matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +668,7 @@ struct CallIndirectOpLowering
using Super::Super;
LogicalResult
- matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
@@ -679,41 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, extract the aligned pointer to
- // be returned from the memref descriptor.
- for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
- Type oldTy = std::get<0>(it).getType();
- Value newOperand = std::get<1>(it);
- if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
- cast<BaseMemRefType>(oldTy))) {
- MemRefDescriptor memrefDesc(newOperand);
- newOperand = memrefDesc.allocatedPtr(rewriter, loc);
- } else if (isa<UnrankedMemRefType>(oldTy)) {
+
+ for (auto [oldOperand, newOperands] :
+ llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+ Type oldTy = oldOperand.getType();
+ if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
+ if (useBarePtrCallConv &&
+ getTypeConverter()->canConvertToBarePtr(memRefType)) {
+ // For the bare-ptr calling convention, extract the aligned pointer to
+ // be returned from the memref descriptor.
+ MemRefDescriptor memrefDesc(newOperands.front());
+ updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
+ continue;
+ }
+ } else if (auto unrankedMemRefType =
+ dyn_cast<UnrankedMemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
+ if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
- updatedOperands.push_back(newOperand);
+ Value updatedDesc =
+ copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+ newOperands.front(), /*toDynamic=*/true);
+ if (!updatedDesc)
+ return failure();
+ updatedOperands.push_back(updatedDesc);
+ continue;
}
- } else {
- updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
- (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
- updatedOperands,
- /*toDynamic=*/true);
+
+ llvm::append_range(updatedOperands, newOperands);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
- if (numArguments <= 1) {
+ if (updatedOperands.size() <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3cfbd89..e516118 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -532,6 +532,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
vector::populateVectorTransferLoweringPatterns(patterns,
/*maxTransferRank=*/1);
+ // Transform N-D vector.from_elements to 1-D vector.from_elements before
+ // conversion.
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 317bfc2..93e370d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -369,6 +370,9 @@ struct LowerGpuOpsToNVVMOpsPass final
{
RewritePatternSet patterns(m.getContext());
populateGpuRewritePatterns(patterns);
+ // Transform N-D vector.from_elements to 1-D vector.from_elements before
+ // conversion.
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(m, std::move(patterns))))
return signalPassFailure();
}
@@ -394,7 +398,7 @@ struct LowerGpuOpsToNVVMOpsPass final
if (!allowedDialectsSet.empty() && !allowed)
continue;
- auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface) {
// Error out if dialect was explicily specified but doesn't implement
// conversion interface.
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index d22364e..8994905 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -79,17 +79,30 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
return canBeBare;
}
-static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
- const unsigned indexBitwidth) {
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
- Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
- ValueRange{minus1, zero});
- Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
- ValueRange{minus1, mbcntLo});
+ NamedAttribute noundef = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+ NamedAttribute lowRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 32)));
+ NamedAttribute highRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 64)));
+ Value mbcntLo = ROCDL::MbcntLoOp::create(
+ rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
+ /*res_attrs=*/
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
+ Value laneId = ROCDL::MbcntHiOp::create(
+ rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
return laneId;
}
+
static constexpr StringLiteral amdgcnDataLayout =
"e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
"-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
@@ -104,18 +117,16 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
LogicalResult
matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
- // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
- // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
-
- Type intTy = IntegerType::get(context, 32);
- Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
- Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
- Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
- ValueRange{minus1, zero});
- Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
- ValueRange{minus1, mbcntLo});
+ // convert to:
+ // %mlo = call noundef range(i32 0, 32)
+ // @llvm.amdgcn.mbcnt.lo(-1, 0)
+ // followed by:
+ // %lid = call noundef range(i32 0, 64)
+ // @llvm.amdgcn.mbcnt.hi(-1, %mlo)
+
+ Value laneId = getLaneId(rewriter, loc);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
@@ -160,6 +171,38 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
const amdgpu::Chipset chipset;
};
+static bool isSupportedReadLaneType(Type type) {
+ // read(first)lane also supports some vector types, but limit it for scalars
+ // for now.
+ return type.isInteger(16) || type.isInteger(32) || type.isInteger(64) ||
+ isa<Float16Type, BFloat16Type, Float32Type, Float64Type,
+ LLVM::LLVMPointerType>(type);
+}
+
+struct GPUSubgroupBroadcastOpToROCDL
+ : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value src = adaptor.getSrc();
+ if (!isSupportedReadLaneType(src.getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported readlane type");
+
+ if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) {
+ rewriter.replaceOpWithNewOp<ROCDL::ReadlaneOp>(op, src.getType(), src,
+ adaptor.getLane());
+ } else { // first_active_lane or any_lane
+ // any_lane is lowered to readfirstlane too, to force value into scalar
+ // register.
+ rewriter.replaceOpWithNewOp<ROCDL::ReadfirstlaneOp>(op, src.getType(),
+ src);
+ }
+ return success();
+ }
+};
+
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
@@ -185,8 +228,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Location loc = op->getLoc();
Value initShflValue = adaptor.getValue();
- const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
- Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
+ Value srcLaneId = getLaneId(rewriter, loc);
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
@@ -317,7 +359,7 @@ struct LowerGpuOpsToROCDLOpsPass final
{
RewritePatternSet patterns(ctx);
populateGpuRewritePatterns(patterns);
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
(void)applyPatternsGreedily(m, std::move(patterns));
}
@@ -453,7 +495,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
// TODO: Add alignment for workgroup memory
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
- patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
+ patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
+ GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
populateMathToROCDLConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index fce7a3f..522e914 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
results.push_back(d.memRefDescPtr(builder, loc));
}
-void UnrankedMemRefDescriptor::computeSizes(
+Value UnrankedMemRefDescriptor::computeSize(
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
- SmallVectorImpl<Value> &sizes) {
- if (values.empty())
- return;
- assert(values.size() == addressSpaces.size() &&
- "must provide address space for each descriptor");
+ UnrankedMemRefDescriptor desc, unsigned addressSpace) {
// Cache the index type.
Type indexType = typeConverter.getIndexType();
@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
- sizes.reserve(sizes.size() + values.size());
- for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
- // Emit IR computing the memory necessary to store the descriptor. This
- // assumes the descriptor to be
- // { type*, type*, index, index[rank], index[rank] }
- // and densely packed, so the total size is
- // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
- // TODO: consider including the actual size (including eventual padding due
- // to data layout) into the unranked descriptor.
- Value pointerSize = createIndexAttrConstant(
- builder, loc, indexType,
- llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
- Value doublePointerSize =
- LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
-
- // (1 + 2 * rank) * sizeof(index)
- Value rank = desc.rank(builder, loc);
- Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
- Value doubleRankIncremented =
- LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
- Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
- doubleRankIncremented, indexSize);
-
- // Total allocation size.
- Value allocationSize = LLVM::AddOp::create(
- builder, loc, indexType, doublePointerSize, rankIndexSize);
- sizes.push_back(allocationSize);
- }
+ // Emit IR computing the memory necessary to store the descriptor. This
+ // assumes the descriptor to be
+ // { type*, type*, index, index[rank], index[rank] }
+ // and densely packed, so the total size is
+ // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+ // TODO: consider including the actual size (including eventual padding due
+ // to data layout) into the unranked descriptor.
+ Value pointerSize = createIndexAttrConstant(
+ builder, loc, indexType,
+ llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
+ Value doublePointerSize =
+ LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
+
+ // (1 + 2 * rank) * sizeof(index)
+ Value rank = desc.rank(builder, loc);
+ Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
+ Value doubleRankIncremented =
+ LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+ Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+ doubleRankIncremented, indexSize);
+
+ // Total allocation size.
+ Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
+ doublePointerSize, rankIndexSize);
+ return allocationSize;
}
Value UnrankedMemRefDescriptor::allocatedPtr(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 2568044..48a0319 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -216,34 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
return memRefDescriptor;
}
-LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
- OpBuilder &builder, Location loc, TypeRange origTypes,
- SmallVectorImpl<Value> &operands, bool toDynamic) const {
- assert(origTypes.size() == operands.size() &&
- "expected as may original types as operands");
-
- // Find operands of unranked memref type and store them.
- SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
- SmallVector<unsigned> unrankedAddressSpaces;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
- unrankedMemrefs.emplace_back(operands[i]);
- FailureOr<unsigned> addressSpace =
- getTypeConverter()->getMemRefAddressSpace(memRefType);
- if (failed(addressSpace))
- return failure();
- unrankedAddressSpaces.emplace_back(*addressSpace);
- }
- }
-
- if (unrankedMemrefs.empty())
- return success();
-
- // Compute allocation sizes.
- SmallVector<Value> sizes;
- UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
- unrankedMemrefs, unrankedAddressSpaces,
- sizes);
+Value ConvertToLLVMPattern::copyUnrankedDescriptor(
+ OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
+ Value operand, bool toDynamic) const {
+ // Convert memory space.
+ FailureOr<unsigned> addressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memRefType);
+ if (failed(addressSpace))
+ return {};
// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
@@ -254,52 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
if (failed(mallocFunc))
- return failure();
+ return {};
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
if (failed(freeFunc))
- return failure();
+ return {};
}
- unsigned unrankedMemrefPos = 0;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- Type type = origTypes[i];
- if (!isa<UnrankedMemRefType>(type))
- continue;
- Value allocationSize = sizes[unrankedMemrefPos++];
- UnrankedMemRefDescriptor desc(operands[i]);
-
- // Allocate memory, copy, and free the source if necessary.
- Value memory =
- toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
- allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
- Value source = desc.memRefDescPtr(builder, loc);
- LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
- if (!toDynamic)
- LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
-
- // Create a new descriptor. The same descriptor can be returned multiple
- // times, attempting to modify its pointer can lead to memory leaks
- // (allocated twice and overwritten) or double frees (the caller does not
- // know if the descriptor points to the same memory).
- Type descriptorType = getTypeConverter()->convertType(type);
- if (!descriptorType)
- return failure();
- auto updatedDesc =
- UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
- Value rank = desc.rank(builder, loc);
- updatedDesc.setRank(builder, loc, rank);
- updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ UnrankedMemRefDescriptor desc(operand);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ builder, loc, *getTypeConverter(), desc, *addressSpace);
+
+ // Allocate memory, copy, and free the source if necessary.
+ Value memory = toDynamic
+ ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
+ Value source = desc.memRefDescPtr(builder, loc);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
+ if (!toDynamic)
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
+
+ // Create a new descriptor. The same descriptor can be returned multiple
+ // times, attempting to modify its pointer can lead to memory leaks
+ // (allocated twice and overwritten) or double frees (the caller does not
+ // know if the descriptor points to the same memory).
+ Type descriptorType = getTypeConverter()->convertType(memRefType);
+ if (!descriptorType)
+ return {};
+ auto updatedDesc =
+ UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+ Value rank = desc.rank(builder, loc);
+ updatedDesc.setRank(builder, loc, rank);
+ updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ return updatedDesc;
+}
- operands[i] = updatedDesc;
+LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
+ OpBuilder &builder, Location loc, TypeRange origTypes,
+ SmallVectorImpl<Value> &operands, bool toDynamic) const {
+ assert(origTypes.size() == operands.size() &&
+ "expected as may original types as operands");
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
+ Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
+ operands[i], toDynamic);
+ if (!updatedDesc)
+ return failure();
+ operands[i] = updatedDesc;
+ }
}
-
return success();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf56..cb9dea1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
+
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
- Type type, bool useBarePtrCallConv) const {
- if (useBarePtrCallConv)
- if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
- return convertMemRefToBarePtr(memrefTy);
-
- return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+ Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+ if (useBarePtrCallConv) {
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+ Type converted = convertMemRefToBarePtr(memrefTy);
+ if (!converted)
+ return failure();
+ result.push_back(converted);
+ return success();
+ }
+ }
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
- ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const {
- assert(stdTypes.size() == values.size() &&
- "The number of types and values doesn't match");
- for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
- values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
- memrefTy, values[i]);
+ return convertType(type, result);
}
/// Convert a non-empty list of types of values produced by an operation into an
@@ -706,23 +699,35 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
- bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+ TypeRange types, bool useBarePtrCallConv,
+ SmallVector<SmallVector<Type>> *groupedTypes,
+ int64_t *numConvertedTypes) const {
assert(!types.empty() && "expected non-empty list of type");
+ assert((!groupedTypes || groupedTypes->empty()) &&
+ "expected groupedTypes to be empty");
useBarePtrCallConv |= options.useBarePtrCallConv;
- if (types.size() == 1)
- return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
SmallVector<Type> resultTypes;
resultTypes.reserve(types.size());
+ size_t sizeBefore = 0;
for (auto t : types) {
- auto converted = convertCallingConventionType(t, useBarePtrCallConv);
- if (!converted || !LLVM::isCompatibleType(converted))
+ if (failed(
+ convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
return {};
- resultTypes.push_back(converted);
+ if (groupedTypes) {
+ SmallVector<Type> &group = groupedTypes->emplace_back();
+ llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+ }
+ sizeBefore = resultTypes.size();
}
+ if (numConvertedTypes)
+ *numConvertedTypes = resultTypes.size();
+ if (resultTypes.size() == 1)
+ return resultTypes.front();
+ if (resultTypes.empty())
+ return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
@@ -740,40 +745,50 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
return allocated;
}
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ValueRange adaptorOperands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
+ SmallVector<ValueRange> ranges;
+ for (size_t i = 0, e = adaptorOperands.size(); i < e; i++)
+ ranges.push_back(adaptorOperands.slice(i, 1));
+ return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
+}
+
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
- promotedOperands.reserve(operands.size());
+ promotedOperands.reserve(adaptorOperands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
- for (auto it : llvm::zip(opOperands, operands)) {
- auto operand = std::get<0>(it);
- auto llvmOperand = std::get<1>(it);
-
+ for (auto [operand, llvmOperand] :
+ llvm::zip_equal(opOperands, adaptorOperands)) {
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (isa<MemRefType>(operand.getType())) {
- MemRefDescriptor desc(llvmOperand);
- llvmOperand = desc.alignedPtr(builder, loc);
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor desc(llvmOperand.front());
+ promotedOperands.push_back(desc.alignedPtr(builder, loc));
+ continue;
} else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
if (isa<UnrankedMemRefType>(operand.getType())) {
- UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
promotedOperands);
continue;
}
if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
promotedOperands);
continue;
}
}
- promotedOperands.push_back(llvmOperand);
+ llvm::append_range(promotedOperands, llvmOperand);
}
return promotedOperands;
}
@@ -802,11 +817,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.append(converted.begin(), converted.end());
return success();
}
- auto converted = converter.convertType(type);
- if (!converted)
- return failure();
- result.push_back(converted);
- return success();
+ return converter.convertType(type, result);
}
/// Callback to convert function argument types. It converts MemRef function
@@ -814,11 +825,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
- if (!llvmTy)
- return failure();
-
- result.push_back(llvmTy);
- return success();
+ return converter.convertCallingConventionType(
+ type, result,
+ /*useBarePointerCallConv=*/true);
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d..2b7bdc9 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -17,11 +17,13 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
+#include <numeric>
using namespace mlir;
@@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+ OpBuilder &builder) {
+ assert(isMemRefTypeLegalForEmitC(memrefType) &&
+ "incompatible memref type for EmitC conversion");
+ emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
+ builder, loc, emitc::SizeTType::get(builder.getContext()),
+ builder.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(builder.getContext(),
+ {TypeAttr::get(memrefType.getElementType())}));
+
+ IndexType indexType = builder.getIndexType();
+ int64_t numElements = std::accumulate(memrefType.getShape().begin(),
+ memrefType.getShape().end(), int64_t{1},
+ std::multiplies<int64_t>());
+ emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
+ builder, loc, indexType, builder.getIndexAttr(numElements));
+
+ Type sizeTType = emitc::SizeTType::get(builder.getContext());
+ emitc::MulOp totalSizeBytes = emitc::MulOp::create(
+ builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+ return totalSizeBytes.getResult();
+}
+
+static emitc::ApplyOp
+createPointerFromEmitcArray(Location loc, OpBuilder &builder,
+ TypedValue<emitc::ArrayType> arrayValue) {
+
+ emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
+ builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
+
+ emitc::ArrayType arrayType = arrayValue.getType();
+ llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
+ emitc::SubscriptOp subPtr =
+ emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
+ emitc::ApplyOp ptr = emitc::ApplyOp::create(
+ builder, loc, emitc::PointerType::get(arrayType.getElementType()),
+ builder.getStringAttr("&"), subPtr);
+
+ return ptr;
+}
+
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -112,19 +156,21 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
Type elementType = memrefType.getElementType();
IndexType indexType = rewriter.getIndexType();
- emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
- loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
+ rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
+ ValueRange{},
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
int64_t numElements = 1;
for (int64_t dimSize : memrefType.getShape()) {
numElements *= dimSize;
}
- Value numElementsValue = rewriter.create<emitc::ConstantOp>(
- loc, indexType, rewriter.getIndexAttr(numElements));
+ Value numElementsValue = emitc::ConstantOp::create(
+ rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
- Value totalSizeBytes = rewriter.create<emitc::MulOp>(
- loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+ Value totalSizeBytes =
+ emitc::MulOp::create(rewriter, loc, sizeTType,
+ sizeofElementOp.getResult(0), numElementsValue);
emitc::CallOpaqueOp allocCall;
StringAttr allocFunctionName;
@@ -132,8 +178,8 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
SmallVector<Value, 2> argsVec;
if (allocOp.getAlignment()) {
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
- alignmentValue = rewriter.create<emitc::ConstantOp>(
- loc, sizeTType,
+ alignmentValue = emitc::ConstantOp::create(
+ rewriter, loc, sizeTType,
rewriter.getIntegerAttr(indexType,
allocOp.getAlignment().value_or(0)));
argsVec.push_back(alignmentValue);
@@ -144,21 +190,62 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
argsVec.push_back(totalSizeBytes);
ValueRange args(argsVec);
- allocCall = rewriter.create<emitc::CallOpaqueOp>(
- loc,
+ allocCall = emitc::CallOpaqueOp::create(
+ rewriter, loc,
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
- emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
- loc, targetPointerType, allocCall.getResult(0));
+ emitc::CastOp castOp = emitc::CastOp::create(
+ rewriter, loc, targetPointerType, allocCall.getResult(0));
rewriter.replaceOp(allocOp, castOp);
return success();
}
};
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = copyOp.getLoc();
+ MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
+ MemRefType targetMemrefType =
+ cast<MemRefType>(copyOp.getTarget().getType());
+
+ if (!isMemRefTypeLegalForEmitC(srcMemrefType))
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible source memref type for EmitC conversion");
+
+ if (!isMemRefTypeLegalForEmitC(targetMemrefType))
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible target memref type for EmitC conversion");
+
+ auto srcArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ emitc::ApplyOp srcPtr =
+ createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
+
+ auto targetArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+ emitc::ApplyOp targetPtr =
+ createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
+
+ emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
+ rewriter, loc, TypeRange{}, "memcpy",
+ ValueRange{
+ targetPtr.getResult(), srcPtr.getResult(),
+ calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+ rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -320,6 +407,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+ ConvertGetGlobal, ConvertLoad, ConvertStore>(
+ converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76..a073a9a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringRef.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +29,15 @@ namespace mlir {
using namespace mlir;
namespace {
+
+emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
+ StringRef headerName) {
+ StringAttr includeAttr = builder.getStringAttr(headerName);
+ return emitc::IncludeOp::create(
+ builder, module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+}
+
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
@@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass
return signalPassFailure();
mlir::ModuleOp module = getOperation();
+ llvm::SmallSet<StringRef, 4> existingHeaders;
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ module.walk([&](mlir::emitc::IncludeOp includeOp) {
+ if (includeOp.getIsStandardInclude())
+ existingHeaders.insert(includeOp.getInclude());
+ });
+
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
- if (callOp.getCallee() != alignedAllocFunctionName &&
- callOp.getCallee() != mallocFunctionName) {
+ StringRef expectedHeader;
+ if (callOp.getCallee() == alignedAllocFunctionName ||
+ callOp.getCallee() == mallocFunctionName)
+ expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader;
+ else if (callOp.getCallee() == memcpyFunctionName)
+ expectedHeader =
+ options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
+ else
return mlir::WalkResult::advance();
+ if (!existingHeaders.contains(expectedHeader)) {
+ addStandardHeader(builder, module, expectedHeader);
+ existingHeaders.insert(expectedHeader);
}
-
- for (auto &op : *module.getBody()) {
- emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
- if (!includeOp) {
- continue;
- }
- if (includeOp.getIsStandardInclude() &&
- ((options.lowerToCpp &&
- includeOp.getInclude() == cppStandardLibraryHeader) ||
- (!options.lowerToCpp &&
- includeOp.getInclude() == cStandardLibraryHeader))) {
- return mlir::WalkResult::interrupt();
- }
- }
-
- mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
- StringAttr includeAttr =
- builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
- : cStandardLibraryHeader);
- builder.create<mlir::emitc::IncludeOp>(
- module.getLoc(), includeAttr,
- /*is_standard_include=*/builder.getUnitAttr());
- return mlir::WalkResult::interrupt();
+ return mlir::WalkResult::advance();
});
}
};
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d6bdd34..262e0e7 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
- SmallVector<Value, 1> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- result, resultAddrSpace, sizes);
- Value resultUnderlyingSize = sizes.front();
+ Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
Value resultUnderlyingDesc =
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ private:
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
- SmallVector<Value, 4> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- targetDesc, addressSpace, sizes);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
Value underlyingDescPtr = LLVM::AllocaOp::create(
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
- sizes.front());
+ allocationSize);
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.
@@ -1872,6 +1869,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::xori:
+ return LLVM::AtomicBinOp::_xor;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index c7ecd83..2e00b42 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
+#include <limits>
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -475,7 +476,12 @@ struct MemoryRequirements {
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
-calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
+calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
+ uint64_t preferredAlignment) {
+ if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
+ return failure();
+ }
+
MLIRContext *ctx = accessedPtr.getContext();
auto memoryAccess = spirv::MemoryAccess::None;
@@ -484,7 +490,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
}
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
- if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
+ bool mayOmitAlignment =
+ !preferredAlignment &&
+ ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
+ if (mayOmitAlignment) {
if (memoryAccess == spirv::MemoryAccess::None) {
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
}
@@ -493,6 +502,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
}
// PhysicalStorageBuffers require the `Aligned` attribute.
+ // Other storage types may show an `Aligned` attribute.
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
if (!pointeeType)
return failure();
@@ -504,7 +514,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
- auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+ auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
+ auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
return MemoryRequirements{memAccessAttr, alignment};
}
@@ -518,16 +529,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
"Must be called on either memref::LoadOp or memref::StoreOp");
- Operation *memrefAccessOp = loadOrStoreOp.getOperation();
- auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
- spirv::attributeName<spirv::MemoryAccess>());
- auto memrefAlignment =
- memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
- if (memrefMemAccess && memrefAlignment)
- return MemoryRequirements{memrefMemAccess, memrefAlignment};
-
return calculateMemoryRequirements(accessedPtr,
- loadOrStoreOp.getNontemporal());
+ loadOrStoreOp.getNontemporal(),
+ loadOrStoreOp.getAlignment().value_or(0));
}
LogicalResult
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2549a9c..37d12ba 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
+ auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
Value ldMatrixResult = NVVM::LdMatrixOp::create(
b, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
- : NVVM::MMALayout::row);
+ : NVVM::MMALayout::row,
+ /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
@@ -394,11 +396,6 @@ struct ConvertNVGPUToNVVMPass
: public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
using Base::Base;
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
- arith::ArithDialect>();
- }
-
void runOnOperation() override {
LowerToLLVMOptions options(&getContext());
RewritePatternSet patterns(&getContext());
@@ -1029,8 +1026,10 @@ struct NVGPUTmaAsyncStoreOpLowering
coords[index] = truncToI32(b, value);
}
+ // TODO: Enhance the NVGPU Op for other modes too
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
- op, adaptor.getTensorMapDescriptor(), dest, coords,
+ op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
+ NVVM::TMAStoreMode::TILE, // default is TILE mode
adaptor.getPredicate());
return success();
}
@@ -1104,12 +1103,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LDBG() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
+ << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
+ << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
@@ -1399,14 +1396,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LDBG() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
- << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
- << "][" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
- << "])";
+ LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+ << "(A[" << (iterationM * wgmmaM) << ":"
+ << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
+ << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+ << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
+ << "][" << 0 << ":" << wgmmaN << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1700,8 +1695,10 @@ struct NVGPUTmaPrefetchOpLowering
LogicalResult
matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
- op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
+ rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
+ op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
+ adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
+ /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
return success();
}
};
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 91788f9..314cbed 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -26,6 +26,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
@@ -57,12 +58,13 @@ struct PtxLowering
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
LDBG() << op.getPtx();
- PtxBuilder generator(op, rewriter);
- op.getAsmValues(rewriter, asmValues);
+ bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
+ PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
- LDBG() << asmValue << "\t Modifier : " << &modifier;
- generator.insertValue(asmValue, modifier);
+ LDBG() << asmValue << "\t Modifier : " << modifier;
+ if (failed(generator.insertValue(asmValue, modifier)))
+ return failure();
}
generator.buildAndReplaceOp();
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 5bd1d49..d57926ec 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <queue>
#define DEBUG_TYPE "pdl-predicate-tree"
@@ -544,7 +545,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
Value value = opIndex.parent;
TypeSwitch<Operation *>(value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
// Get users and iterate over them.
Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
@@ -618,19 +619,15 @@ static Value buildPredicateList(pdl::PatternOp pattern,
RootOrderingGraph graph;
ParentMaps parentMaps;
buildCostGraph(roots, graph, parentMaps);
- LLVM_DEBUG({
- llvm::dbgs() << "Graph:\n";
- for (auto &target : graph) {
- llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
- << "\n";
- for (auto &source : target.second) {
- RootOrderingEntry &entry = source.second;
- llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
- << ":" << entry.cost.second << " via "
- << entry.connector.getLoc() << "\n";
- }
+ LDBG() << "Graph:";
+ for (auto &target : graph) {
+ LDBG() << " * " << target.first.getLoc() << " " << target.first;
+ for (auto &source : target.second) {
+ RootOrderingEntry &entry = source.second;
+ LDBG() << " <- " << source.first << ": " << entry.cost.first << ":"
+ << entry.cost.second << " via " << entry.connector.getLoc();
}
- });
+ }
// Solve the optimal branching problem for each candidate root, or use the
// provided one.
@@ -638,11 +635,11 @@ static Value buildPredicateList(pdl::PatternOp pattern,
OptimalBranching::EdgeList bestEdges;
if (!bestRoot) {
unsigned bestCost = 0;
- LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
+ LDBG() << "Candidate roots:";
for (Value root : roots) {
OptimalBranching solver(graph, root);
unsigned cost = solver.solve();
- LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
+ LDBG() << " * " << root << ": " << cost;
if (!bestRoot || bestCost > cost) {
bestCost = cost;
bestRoot = root;
@@ -656,18 +653,15 @@ static Value buildPredicateList(pdl::PatternOp pattern,
}
// Print the best solution.
- LLVM_DEBUG({
- llvm::dbgs() << "Best tree:\n";
- for (const std::pair<Value, Value> &edge : bestEdges) {
- llvm::dbgs() << " * " << edge.first;
- if (edge.second)
- llvm::dbgs() << " <- " << edge.second;
- llvm::dbgs() << "\n";
- }
- });
+ LDBG() << "Best tree:";
+ for (const std::pair<Value, Value> &edge : bestEdges) {
+ if (edge.second)
+ LDBG() << " * " << edge.first << " <- " << edge.second;
+ else
+ LDBG() << " * " << edge.first;
+ }
- LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
+ LDBG() << "Calling key getTreePredicates (Value: " << bestRoot << ")";
// The best root is the starting point for the traversal. Get the tree
// predicates for the DAG rooted at bestRoot.
@@ -691,7 +685,7 @@ static Value buildPredicateList(pdl::PatternOp pattern,
// Determine the connector.
Value connector = graph[target][source].connector;
assert(connector && "invalid edge");
- LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
+ LDBG() << " * Connector: " << connector.getLoc();
DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
Position *pos = valueToPosition.lookup(connector);
assert(pos && "connector has not been traversed yet");
@@ -806,9 +800,9 @@ static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
/// Get or insert a child matcher for the given parent switch node, given a
/// predicate and parent pattern.
-std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
- OrderedPredicate *predicate,
- pdl::PatternOp pattern) {
+static std::unique_ptr<MatcherNode> &
+getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate,
+ pdl::PatternOp pattern) {
assert(isSamePredicate(node, predicate) &&
"expected matcher to equal the given predicate");
diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt
new file mode 100644
index 0000000..2d416be
--- /dev/null
+++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRPtrToLLVM
+ PtrToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRPtrDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ )
diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
new file mode 100644
index 0000000..a0758aa
--- /dev/null
+++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
@@ -0,0 +1,440 @@
+//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <type_traits>
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FromPtrOpConversion
+//===----------------------------------------------------------------------===//
+struct FromPtrOpConversion : public ConvertOpToLLVMPattern<ptr::FromPtrOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOpConversion
+//===----------------------------------------------------------------------===//
+struct GetMetadataOpConversion
+ : public ConvertOpToLLVMPattern<ptr::GetMetadataOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// PtrAddOpConversion
+//===----------------------------------------------------------------------===//
+struct PtrAddOpConversion : public ConvertOpToLLVMPattern<ptr::PtrAddOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// ToPtrOpConversion
+//===----------------------------------------------------------------------===//
+struct ToPtrOpConversion : public ConvertOpToLLVMPattern<ptr::ToPtrOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeOffsetOpConversion
+//===----------------------------------------------------------------------===//
+struct TypeOffsetOpConversion
+ : public ConvertOpToLLVMPattern<ptr::TypeOffsetOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Internal functions
+//===----------------------------------------------------------------------===//
+
+// Function to create an LLVM struct type representing a memref metadata.
+static FailureOr<LLVM::LLVMStructType>
+createMemRefMetadataType(MemRefType type,
+ const LLVMTypeConverter &typeConverter) {
+ MLIRContext *context = type.getContext();
+ // Get the address space.
+ FailureOr<unsigned> addressSpace = typeConverter.getMemRefAddressSpace(type);
+ if (failed(addressSpace))
+ return failure();
+
+ // Get pointer type (using address space 0 by default)
+ auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
+
+ // Get the strides offsets and shape.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(type.getStridesAndOffset(strides, offset)))
+ return failure();
+ ArrayRef<int64_t> shape = type.getShape();
+
+ // Use index type from the type converter for the descriptor elements
+ Type indexType = typeConverter.getIndexType();
+
+ // For a ranked memref, the descriptor contains:
+ // 1. The pointer to the allocated data
+ // 2. The pointer to the aligned data
+ // 3. The dynamic offset?
+ // 4. The dynamic sizes?
+ // 5. The dynamic strides?
+ SmallVector<Type, 5> elements;
+
+ // Allocated pointer.
+ elements.push_back(ptrType);
+
+ // Potentially add the dynamic offset.
+ if (offset == ShapedType::kDynamic)
+ elements.push_back(indexType);
+
+ // Potentially add the dynamic sizes.
+ for (int64_t dim : shape) {
+ if (dim == ShapedType::kDynamic)
+ elements.push_back(indexType);
+ }
+
+ // Potentially add the dynamic strides.
+ for (int64_t stride : strides) {
+ if (stride == ShapedType::kDynamic)
+ elements.push_back(indexType);
+ }
+ return LLVM::LLVMStructType::getLiteral(context, elements);
+}
+
+//===----------------------------------------------------------------------===//
+// FromPtrOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult FromPtrOpConversion::matchAndRewrite(
+ ptr::FromPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Get the target memref type
+ auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
+ if (!mTy)
+ return rewriter.notifyMatchFailure(op, "Expected memref result type");
+
+ if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
+ return rewriter.notifyMatchFailure(
+ op, "Can convert only memrefs with metadata");
+ }
+
+ // Convert the result type
+ Type descriptorTy = getTypeConverter()->convertType(mTy);
+ if (!descriptorTy)
+ return rewriter.notifyMatchFailure(op, "Failed to convert result type");
+
+ // Get the strides, offsets and shape.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(mTy.getStridesAndOffset(strides, offset))) {
+ return rewriter.notifyMatchFailure(op,
+ "Failed to get the strides and offset");
+ }
+ ArrayRef<int64_t> shape = mTy.getShape();
+
+ // Create a new memref descriptor
+ Location loc = op.getLoc();
+ auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
+
+ // Set the allocated and aligned pointers.
+ desc.setAllocatedPtr(
+ rewriter, loc,
+ rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getMetadata(), 0));
+ desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
+
+ // Extract metadata from the passed struct.
+ unsigned fieldIdx = 1;
+
+ // Set dynamic offset if needed.
+ if (offset == ShapedType::kDynamic) {
+ Value offsetValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setOffset(rewriter, loc, offsetValue);
+ } else {
+ desc.setConstantOffset(rewriter, loc, offset);
+ }
+
+ // Set dynamic sizes if needed.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ if (dim == ShapedType::kDynamic) {
+ Value sizeValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setSize(rewriter, loc, i, sizeValue);
+ } else {
+ desc.setConstantSize(rewriter, loc, i, dim);
+ }
+ }
+
+ // Set dynamic strides if needed.
+ for (auto [i, stride] : llvm::enumerate(strides)) {
+ if (stride == ShapedType::kDynamic) {
+ Value strideValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getMetadata(), fieldIdx++);
+ desc.setStride(rewriter, loc, i, strideValue);
+ } else {
+ desc.setConstantStride(rewriter, loc, i, stride);
+ }
+ }
+
+ rewriter.replaceOp(op, static_cast<Value>(desc));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult GetMetadataOpConversion::matchAndRewrite(
+ ptr::GetMetadataOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
+ if (!mTy)
+ return rewriter.notifyMatchFailure(op, "Only memref metadata is supported");
+
+ // Get the metadata type.
+ FailureOr<LLVM::LLVMStructType> mdTy =
+ createMemRefMetadataType(mTy, *getTypeConverter());
+ if (failed(mdTy)) {
+ return rewriter.notifyMatchFailure(op,
+ "Failed to create the metadata type");
+ }
+
+ // Get the memref descriptor.
+ MemRefDescriptor descriptor(adaptor.getPtr());
+
+ // Get the strides offsets and shape.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(mTy.getStridesAndOffset(strides, offset))) {
+ return rewriter.notifyMatchFailure(op,
+ "Failed to get the strides and offset");
+ }
+ ArrayRef<int64_t> shape = mTy.getShape();
+
+ // Create a new LLVM struct to hold the metadata
+ Location loc = op.getLoc();
+ Value sV = rewriter.create<LLVM::UndefOp>(loc, *mdTy);
+
+ // First element is the allocated pointer.
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.allocatedPtr(rewriter, loc), 0);
+
+ // Track the current field index.
+ unsigned fieldIdx = 1;
+
+ // Add dynamic offset if needed.
+ if (offset == ShapedType::kDynamic) {
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
+ }
+
+ // Add dynamic sizes if needed.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ if (dim != ShapedType::kDynamic)
+ continue;
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
+ }
+
+ // Add dynamic strides if needed
+ for (auto [i, stride] : llvm::enumerate(strides)) {
+ if (stride != ShapedType::kDynamic)
+ continue;
+ sV = rewriter.create<LLVM::InsertValueOp>(
+ loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
+ }
+ rewriter.replaceOp(op, sV);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PtrAddOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Get and check the base.
+ Value base = adaptor.getBase();
+ if (!isa<LLVM::LLVMPointerType>(base.getType()))
+ return rewriter.notifyMatchFailure(op, "Incompatible pointer type");
+
+ // Get the offset.
+ Value offset = adaptor.getOffset();
+
+ // Ptr assumes the offset is in bytes.
+ Type elementType = IntegerType::get(rewriter.getContext(), 8);
+
+ // Convert the `ptradd` flags.
+ LLVM::GEPNoWrapFlags flags;
+ switch (op.getFlags()) {
+ case ptr::PtrAddFlags::none:
+ flags = LLVM::GEPNoWrapFlags::none;
+ break;
+ case ptr::PtrAddFlags::nusw:
+ flags = LLVM::GEPNoWrapFlags::nusw;
+ break;
+ case ptr::PtrAddFlags::nuw:
+ flags = LLVM::GEPNoWrapFlags::nuw;
+ break;
+ case ptr::PtrAddFlags::inbounds:
+ flags = LLVM::GEPNoWrapFlags::inbounds;
+ break;
+ }
+
+ // Create the GEP operation with appropriate arguments
+ rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType,
+ base, ValueRange{offset}, flags);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ToPtrOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Bail if it's not a memref.
+ if (!isa<MemRefType>(op.getPtr().getType()))
+ return rewriter.notifyMatchFailure(op, "Expected a memref input");
+
+ // Extract the aligned pointer from the memref descriptor.
+ rewriter.replaceOp(
+ op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeOffsetOpConversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult TypeOffsetOpConversion::matchAndRewrite(
+ ptr::TypeOffsetOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Convert the type attribute.
+ Type type = getTypeConverter()->convertType(op.getElementType());
+ if (!type)
+ return rewriter.notifyMatchFailure(op, "Couldn't convert the type");
+
+ // Convert the result type.
+ Type rTy = getTypeConverter()->convertType(op.getResult().getType());
+ if (!rTy)
+ return rewriter.notifyMatchFailure(op, "Couldn't convert the result type");
+
+ // TODO: Use MLIR's data layout. We don't use it because overall support is
+ // still flaky.
+
+ // Create an LLVM pointer type for the GEP operation.
+ auto ptrTy = LLVM::LLVMPointerType::get(getContext());
+
+ // Create a GEP operation to compute the offset of the type.
+ auto offset =
+ LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
+ LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
+ ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
+
+ // Replace the original op with a PtrToIntOp using the computed offset.
+ rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert Ptr to LLVM.
+struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+ void loadDependentDialects(MLIRContext *context) const final {
+ context->loadDialect<LLVM::LLVMDialect>();
+ }
+
+ /// Hook for derived dialect interface to provide conversion patterns
+ /// and mark dialect legal for the conversion target.
+ void populateConvertToLLVMConversionPatterns(
+ ConversionTarget &target, LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) const final {
+ ptr::populatePtrToLLVMConversionPatterns(converter, patterns);
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// API
+//===----------------------------------------------------------------------===//
+
+void mlir::ptr::populatePtrToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ // Add address space conversions.
+ converter.addTypeAttributeConversion(
+ [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
+ -> TypeConverter::AttributeConversionResult {
+ if (type.getMemorySpace() != memorySpace)
+ return TypeConverter::AttributeConversionResult::na();
+ return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
+ });
+
+ // Add type conversions.
+ converter.addConversion([&](ptr::PtrType type) -> Type {
+ std::optional<Attribute> maybeAttr =
+ converter.convertTypeAttribute(type, type.getMemorySpace());
+ auto memSpace =
+ maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
+ if (!memSpace)
+ return {};
+ return LLVM::LLVMPointerType::get(type.getContext(),
+ memSpace.getValue().getSExtValue());
+ });
+
+ // Convert ptr metadata of memref type.
+ converter.addConversion([&](ptr::PtrMetadataType type) -> Type {
+ auto mTy = dyn_cast<MemRefType>(type.getType());
+ if (!mTy)
+ return {};
+ FailureOr<LLVM::LLVMStructType> res =
+ createMemRefMetadataType(mTy, converter);
+ return failed(res) ? Type() : res.value();
+ });
+
+ // Add conversion patterns.
+ patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
+ ToPtrOpConversion, TypeOffsetOpConversion>(converter);
+}
+
+void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
+ dialect->addInterfaces<PtrToLLVMDialectInterface>();
+ });
+}
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index ba448e4..37cfc9f 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -382,8 +382,11 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
- auto comparison = arith::CmpIOp::create(
- rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound);
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto comparison =
+ arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 84cbd86..1f239aa 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -154,6 +154,10 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables;
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index badd2f6..7d0a236 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -27,7 +27,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <optional>
#define DEBUG_TYPE "loops-to-gpu"
@@ -134,7 +134,7 @@ static LogicalResult checkAffineLoopNestMappable(AffineForOp forOp,
unsigned numBlockDims,
unsigned numThreadDims) {
if (numBlockDims < 1 || numThreadDims < 1) {
- LLVM_DEBUG(llvm::dbgs() << "nothing to map");
+ LDBG() << "nothing to map";
return success();
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372a..c4a9fc2 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,18 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
- ConversionTarget target(*module.getContext());
- target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
- target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
- memref::MemRefDialect>();
-
RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
- return applyPartialConversion(module, target, frozen);
+ walkAndApplyPatterns(module, frozen);
+ auto status = module.walk([](Operation *op) {
+ if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
+ op->emitError("unconverted operation found");
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return failure(status.wasInterrupted());
}
/// A pass converting SCF operations to OpenMP operations.
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index dc92367f..55ed31e 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -178,8 +178,14 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
- newIndVar, adaptor.getUpperBound());
+ Value cmpOp;
+ if (forOp.getUnsignedCmp()) {
+ cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
+ } else {
+ cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
+ }
spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
ArrayRef<Value>(), mergeBlock,
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index fa9e544..398ab88 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -301,7 +301,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
// Create mpi::CommRankOp
Location loc = op.getLoc();
- auto ctx = op.getContext();
+ auto *ctx = op.getContext();
Value commWorld =
mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
auto rank = mpi::CommRankOp::create(
@@ -520,7 +520,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
};
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
- auto ctx = kind.getContext();
+ auto *ctx = kind.getContext();
auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
};
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725..e568660 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ public:
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingMode roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingMode::DOUBLE_ROUND &&
+ roundingMode != RoundingMode::SINGLE_ROUND) {
return failure();
}
@@ -100,7 +101,7 @@ public:
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ public:
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingMode roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingMode::DOUBLE_ROUND &&
+ roundingMode != RoundingMode::SINGLE_ROUND) {
return failure();
}
@@ -179,7 +181,7 @@ public:
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de06..d0a431b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
auto nanMode = op.getNanMode();
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagationMode::PROPAGATE)
return result;
// Unordered comparison of NaN against itself will always return true.
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::MulOp>(op)) {
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
- if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
- }
-
- int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ bool shiftIsConstant = true;
+ int32_t shift = 0;
+ if (matchPattern(shiftVal, m_Constant(&shiftElem)))
+ shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ else
+ shiftIsConstant = false;
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
@@ -147,23 +147,26 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value a = args[0];
Value b = args[1];
- if (shift > 0) {
- auto shiftConst =
- arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
+ if (shift > 0 || !shiftIsConstant) {
+ Value shiftConst;
+ if (shiftIsConstant)
+ shiftConst =
+ rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+
if (!a.getType().isInteger(32))
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getStringAttr("SINGLE_ROUND"));
-
- if (elementTy.isInteger(32))
- return result;
+ auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
+ auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
+ RoundingMode::SINGLE_ROUND);
+ auto result =
+ tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+ b, shiftAmount, roundingAttr);
- return arith::TruncIOp::create(rewriter, loc, elementTy, result);
+ return result;
}
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -464,7 +467,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// In the case of "PROPAGATE" semantics no compare and selection is
// required.
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagationMode::PROPAGATE)
return result;
// In the case of "IGNORE" semantics materialize a comparison
@@ -918,6 +921,18 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
if (operands.size() == 1)
return operands;
+ // No need to broadcast for static shape
+ bool hasDynamic = false;
+ for (auto op : operands) {
+ const auto tType = dyn_cast<RankedTensorType>(op.getType());
+ if (tType && !tType.hasStaticShape()) {
+ hasDynamic = true;
+ break;
+ }
+ }
+ if (!hasDynamic)
+ return operands;
+
// Broadcast dynamic dimensions operand by operand
return llvm::map_to_vector(operands, [&](Value operand) {
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1005,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static ValueRange getBroadcastableOperands(Operation *operation,
ValueRange operands) {
// Shift cannot broadcast
- if (isa<tosa::MulOp>(operation))
- return operands.take_front(2);
+ if (isa<tosa::MulOp>(operation)) {
+ DenseElementsAttr shiftElems;
+ // Shift cannot broadcast when it is constant
+ if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
+ return operands.take_front(2);
+ else
+ return operands.take_front(3);
+ }
// Input1_zp and output_zp cannot broadcast
if (isa<tosa::NegateOp>(operation))
return operands.take_front(1);
@@ -1173,7 +1194,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
// NaN propagation has no meaning for non floating point types.
- if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+ if (isa<FloatType>(elementTy) &&
+ op.getNanMode() == NanPropagationMode::IGNORE) {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1358,11 @@ public:
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getRoundingMode() == "INEXACT_ROUND")
+ if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
- if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+ if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1386,11 +1408,10 @@ public:
// is ever true.
bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
+ op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ RoundingMode roundingMode =
+ doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1594,7 @@ public:
auto input = op.getInput();
auto inputTy = cast<RankedTensorType>(input.getType());
auto resultTy = cast<RankedTensorType>(op.getType());
- const bool isBilinear = op.getMode() == "BILINEAR";
+ const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
auto inputH = inputTy.getDimSize(1);
auto inputW = inputTy.getDimSize(2);
@@ -1584,8 +1605,8 @@ public:
return rewriter.notifyMatchFailure(
op, "tosa.resize is not a pure 1x1->1x1 image operation");
- // TODO(suderman): These string values should be declared the TOSA dialect.
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeMode::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1785,7 +1806,8 @@ public:
return rewriter.notifyMatchFailure(
op, "unable to get dynamic dimensions of tosa.resize");
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeMode::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1890,7 +1912,7 @@ public:
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
}
- if (op.getMode() == "NEAREST_NEIGHBOR") {
+ if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1948,7 @@ public:
linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
- assert(op.getMode() == "BILINEAR");
+ assert(op.getMode() == ResizeMode::BILINEAR);
auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
@@ -2291,7 +2313,7 @@ public:
Value predicate;
if (isa<FloatType>(inElementTy)) {
- if (argmaxOp.getNanMode() == "IGNORE") {
+ if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca..6f28849 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ public:
dilationAttr);
rewriter.setInsertionPointAfter(op);
- StringRef nanMode = op.getNanMode();
+ NanPropagationMode nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
// NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ public:
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
- if (nanMode == "IGNORE") {
+ if (nanMode == NanPropagationMode::IGNORE) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ public:
rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
- auto scaled =
- tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
- shift, rewriter.getStringAttr("SINGLE_ROUND"))
- .getResult();
+ auto roundingAttr = RoundingModeAttr::get(
+ rewriter.getContext(), RoundingMode::SINGLE_ROUND);
+
+ auto scaled = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal,
+ multiplier, shift, roundingAttr)
+ .getResult();
// If we have quantization information we need to apply output
// zeropoint.
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000..2d4b2b6
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToAMX
+ VectorToAMX.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAMXDialect
+ MLIRAffineUtils
+ MLIRArithDialect
+ MLIRLinalgUtils
+ MLIRMemRefDialect
+ MLIRSCFDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000..7b9ed1d
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,429 @@
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/DebugLog.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "vector-to-amx"
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
+static bool verifyAmxShape(VectorType vec) {
+ // Check overall shape:
+ // - 2D for plain layout input or output
+ // - 3D for VNNI packed input
+ if (vec.getRank() != 2 && vec.getRank() != 3)
+ return false;
+
+ ArrayRef<int64_t> shape = vec.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = shape[1];
+ unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+ // 3D shape indicates VNNI packed layout.
+ if (vec.getRank() == 3) {
+ int64_t vnniFactor = 32 / elemBitWidth;
+ if (shape.back() != vnniFactor) {
+ LDBG() << "invalid VNNI packing factor";
+ return false;
+ }
+ cols *= vnniFactor;
+ }
+
+ // AMX tile supports up to 16 rows of 64 bytes each.
+ constexpr unsigned maxRows = 16;
+ constexpr unsigned maxBitsPerRow = 64 * 8;
+ return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
+
+/// Check if contraction operands are in AMX-compatible packed VNNI layout.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Expect 3D inputs for VNNI packed data.
+ VectorType lhsType = contractOp.getLhs().getType();
+ VectorType rhsType = contractOp.getRhs().getType();
+ if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs and rhs 3D vectors");
+
+ // Check if shapes are compatible with AMX tile.
+ if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+ !verifyAmxShape(accType))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+ // Validate affine maps.
+ //
+ // Iterators can be ordered arbitrarily. Indexing map positions are based on
+ // operands' target shapes.
+ // The matrix layouts must match the following:
+ // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+ // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+ // - matrix C - [M]x[N]
+ SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+ if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+ mapB.getNumResults() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid input indexing maps");
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Failed to infer contraction dims");
+ // Two reduction dimensions are expected:
+ // - one for the K dimension
+ // - one for the VNNI factor
+ if (dims->k.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expected two reduction dims");
+ assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+ "Invalid parallel contraction dims");
+
+ SmallVector<vector::IteratorType> iteratorTypes =
+ contractOp.getIteratorTypesArray();
+ // Check VNNI dim maps - the innermost dim for A and B inputs.
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+ // Check K dim maps - non-transposed row-major layout.
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+ // Check M and N dim maps - map to non-transposed output.
+ AffineMap mapC = indexingMaps[2];
+ auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+ auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+ if (!mDimC || !nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+ auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+ if (!parallelDimA ||
+ iteratorTypes[parallelDimA.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimA != mDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+ if (!parallelDimB ||
+ iteratorTypes[parallelDimB.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimB != nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+ return success();
+}
+
+/// Validate contraction operands for AMX lowering.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType)
+ return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+ // Check if operand types are compatible with AMX compute ops.
+ bool validElemTypes = false;
+ Type lhsElemType = contractOp.getLhs().getType().getElementType();
+ Type rhsElemType = contractOp.getRhs().getType().getElementType();
+ Type accElemType = accType.getElementType();
+ if (accElemType.isInteger(32)) {
+ validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
+ } else if (accElemType.isF32()) {
+ validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
+ (lhsElemType.isBF16() && rhsElemType.isBF16());
+ }
+ if (!validElemTypes)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid combination of operand types");
+
+ if (failed(isAmxVnniLayout(rewriter, contractOp)))
+ return failure();
+
+ return success();
+}
+
+/// Collapse the two innermost dimensions together.
+static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
+ TypedValue<MemRefType> memref) {
+ int64_t rank = memref.getType().getRank();
+ SmallVector<ReassociationIndices> reassocIndices;
+ for (auto i : llvm::seq<int64_t>(0, rank - 2))
+ reassocIndices.push_back({i});
+ reassocIndices.push_back({rank - 2, rank - 1});
+ return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+ reassocIndices);
+}
+
+/// Attempt to create an AMX tile load/store operation equivalent to the given
+/// vector transfer `xfer` op.
+/// This approach allows to skip longer route through registers and a temporary
+/// buffer otherwise required to move data to/from an AMX tile.
+static Operation *
+loadStoreFromTransfer(PatternRewriter &rewriter,
+ VectorTransferOpInterface xferOp, bool isPacked,
+ TypedValue<amx::TileType> tileToStore = nullptr) {
+ if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
+ return nullptr;
+ if (xferOp.hasOutOfBoundsDim() ||
+ !xferOp.getPermutationMap().isMinorIdentity())
+ return nullptr;
+
+ // Extra checks in case of a write op.
+ // Stores must not be packed.
+ if (isa<vector::TransferWriteOp>(xferOp) &&
+ (!tileToStore || isPacked ||
+ tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
+ return nullptr;
+
+ // Check for a memref source buffer.
+ // AMX data transfer requires at least 2D shape to correctly
+ // infer stride between rows.
+ Value base = xferOp.getBase();
+ auto memTy = dyn_cast<MemRefType>(base.getType());
+ int64_t memRank = memTy.getRank();
+ if (!memTy || memRank < 2)
+ return nullptr;
+
+ // Check that the source buffer has enough contiguous elements to load whole
+ // AMX tile row.
+ //
+ // To ensure correctness, the validation is conservative and expects the
+ // buffer's innermost dimensions to be statically known, equal to or larger
+ // than the vector row length, and equal to the VNNI dimension if applicable.
+ //
+ // This check could be relaxed to accept more arbitrarily shaped buffers as
+ // long as there are enough contiguous elements to load a whole row.
+ if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
+ return nullptr;
+ VectorType vecTy = xferOp.getVectorType();
+ ArrayRef<int64_t> vecShape = vecTy.getShape();
+ ArrayRef<int64_t> memShape = memTy.getShape();
+ if (memShape.back() == ShapedType::kDynamic ||
+ memShape.back() < vecShape.back())
+ return nullptr;
+ if (isPacked &&
+ (memShape.back() != vecShape.back() ||
+ memShape[memShape.size() - 2] == ShapedType::kDynamic ||
+ memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
+ return nullptr;
+
+ // Load values directly from the buffer to an AMX tile.
+ PatternRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(xferOp);
+ Location loc = xferOp.getLoc();
+
+ // Create a subview of the source buffer based on the transfer op to resolve
+ // offsets.
+ SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
+ int64_t vecRank = vecTy.getRank();
+ assert(memRank >= vecRank &&
+ "Expects buffer to be the same or greater rank than vector");
+ SmallVector<int64_t> shape(memRank - vecRank, 1);
+ shape.append(vecShape.begin(), vecShape.end());
+ TypedValue<MemRefType> src =
+ memref::SubViewOp::create(
+ rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
+ .getResult();
+
+ // Collapse the VNNI dimension in case of packing.
+ if (isPacked)
+ src = collapseLastDim(rewriter, src);
+ int64_t rows = vecShape[0];
+ int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
+
+ Operation *amxTileOp = nullptr;
+ if (isa<vector::TransferReadOp>(xferOp)) {
+ amxTileOp =
+ amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
+ } else if (isa<vector::TransferWriteOp>(xferOp)) {
+ amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
+ tileToStore);
+ } else {
+ llvm_unreachable("unsupported vector transfer op");
+ }
+
+ return amxTileOp;
+}
+
+/// Attempt to create an AMX tile load operation equivalent to the given
+/// vector transfer `readOp`.
+/// Returns loaded AMX tile if successful.
+static FailureOr<TypedValue<amx::TileType>>
+loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
+ bool isPacked) {
+ amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
+ loadStoreFromTransfer(rewriter, readOp, isPacked));
+ if (!loadOp)
+ return failure();
+ return loadOp.getRes();
+}
+
+/// Attempt to create an AMX tile store operation equivalent to the given
+/// vector transfer `writeOp`.
+static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
+ vector::TransferWriteOp writeOp,
+ TypedValue<amx::TileType> tileToStore) {
+ return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
+ tileToStore));
+}
+
+/// Load vector values to an AMX tile.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+ TypedValue<VectorType> vec) {
+ Location loc = vec.getLoc();
+
+ VectorType vecTy = vec.getType();
+ bool isPacked = vecTy.getRank() == 3;
+
+ // Try to load tile directly from vector producer's buffer.
+ auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
+ FailureOr<TypedValue<amx::TileType>> tile =
+ loadFromTransfer(rewriter, readOp, isPacked);
+ if (succeeded(tile))
+ return *tile;
+
+ // Transfer the vector to a tile through an intermediate buffer.
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+ vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+ // Collapse the VNNI dimension in case of packing.
+ if (isPacked)
+ buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+ ArrayRef<int64_t> shape = vecTy.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+ {zeroIndex, zeroIndex});
+}
+
+/// Store an AMX tile in a vector.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+ TypedValue<amx::TileType> tile) {
+ Location loc = tile.getLoc();
+
+ // Transfer the tile to a vector through an intermediate buffer.
+ amx::TileType tileTy = tile.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc,
+ MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(2, zeroIndex);
+ amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+
+ auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
+ return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
+}
+
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ if (failed(validateOperands(rewriter, contractOp)))
+ return failure();
+
+ TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+ TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+ auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
+ assert(acc && "Invalid accumulator type");
+ TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+
+ TypedValue<amx::TileType> tileMul;
+ if (acc.getType().getElementType().isFloat()) {
+ tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ } else {
+ tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ }
+
+ // If the contraction result is only written back to memory, try to replace
+ // the vector op with an AMX store directly.
+ Value res = contractOp.getResult();
+ if (res.hasOneUse()) {
+ auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
+ LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
+ if (succeeded(storeRes)) {
+ rewriter.eraseOp(writeOp);
+ rewriter.eraseOp(contractOp);
+ return success();
+ }
+ }
+
+ // Load the result back into a vector.
+ Value newResult = storeTile(rewriter, tileMul);
+ rewriter.replaceOp(contractOp, newResult);
+
+ return success();
+ }
+};
+
+struct ConvertVectorToAMXPass
+ : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+ void runOnOperation() override {
+ MLIRContext &ctx = getContext();
+ RewritePatternSet patterns(&ctx);
+ populateVectorToAMXConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ContractionToAMX>(patterns.getContext());
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3..1ff7d5d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -247,8 +247,9 @@ public:
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
// Resolve alignment.
- unsigned align;
- if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+ unsigned align = loadOrStoreOp.getAlignment().value_or(0);
+ if (!align &&
+ failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
memRefTy, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(loadOrStoreOp,
"could not resolve alignment");
@@ -305,11 +306,11 @@ public:
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
+ base, ptr, adaptor.getIndices(), vType);
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -361,10 +362,10 @@ public:
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+ adaptor.getBase(), ptr, adaptor.getIndices(), vType);
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
@@ -1890,15 +1891,21 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
+ // Only support 1-D vectors. Multi-dimensional vectors should have been
+ // transformed to 1-D vectors by the vector-to-vector transformations before
+ // this.
if (vectorType.getRank() > 1)
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
+ auto constIdx =
+ LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
+ result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
+ val, constIdx);
+ }
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index cf10869..9852df6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorFromElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4..036cbad 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,22 @@ struct VectorLoadOpConverter final
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+ std::optional<uint64_t> alignment = loadOp.getAlignment();
+ if (alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(loadOp,
+ "invalid alignment requirement");
+ }
+
+ auto memoryAccess = spirv::MemoryAccess::None;
+ spirv::MemoryAccessAttr memoryAccessAttr;
+ IntegerAttr alignmentAttr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
@@ -753,7 +769,8 @@ struct VectorLoadOpConverter final
accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
- castedAccessChain);
+ castedAccessChain,
+ memoryAccessAttr, alignmentAttr);
return success();
}
@@ -782,6 +799,12 @@ struct VectorStoreOpConverter final
return rewriter.notifyMatchFailure(
storeOp, "failed to get memref element pointer");
+ std::optional<uint64_t> alignment = storeOp.getAlignment();
+ if (alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(storeOp,
+ "invalid alignment requirement");
+ }
+
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +818,19 @@ struct VectorStoreOpConverter final
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
accessChain);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
- adaptor.getValueToStore());
+ auto memoryAccess = spirv::MemoryAccess::None;
+ spirv::MemoryAccessAttr memoryAccessAttr;
+ IntegerAttr alignmentAttr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+ alignmentAttr);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
index 567083d..e9ad67c5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU
MLIRTransforms
MLIRVectorDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8010755..819c2e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -14,9 +14,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -68,11 +70,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
if (!srcTy)
return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
- // Perform common data transfer checks.
- VectorType vecTy = xferOp.getVectorType();
- if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
- return failure();
-
// Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
@@ -80,6 +77,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");
+ VectorType vecTy = xferOp.getVectorType();
unsigned vecRank = vecTy.getRank();
if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
return rewriter.notifyMatchFailure(
@@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
return ndDesc;
}
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the innermost strides in the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is computed
+// using the inverse and broadcasting-aware version of the permutation map,
+// and is applied to the relevant strides. This ensures that memory accesses
+// are consistent with the logical permutation of vector elements.
+//
+// Example:
+// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
+// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
+// 0]), then after calling this function, the last two strides will be
+// swapped:
+// Original strides: [s0, s1, s2, s3]
+// After permutation: [s0, s1, s3, s2]
+//
+static void adjustStridesForPermutation(AffineMap permMap,
+ SmallVectorImpl<Value> &strides) {
+
+ AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+ SmallVector<unsigned> perms;
+ invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+ SmallVector<int64_t> perms64(perms.begin(), perms.end());
+ strides = applyPermutation(strides, perms64);
+}
+
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ SmallVector<Value> strides;
+ Value baseMemref = xferOp.getBase();
+ AffineMap permMap = xferOp.getPermutationMap();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+ Location loc = xferOp.getLoc();
+ if (memrefType.hasStaticShape()) {
+ int64_t offset;
+ SmallVector<int64_t> intStrides;
+ if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+ return {};
+ // Wrap static strides as MLIR values
+ for (int64_t s : intStrides)
+ strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+ } else {
+ // For dynamic shape memref, use memref.extract_strided_metadata to get
+ // stride values
+ unsigned rank = memrefType.getRank();
+ Type indexType = rewriter.getIndexType();
+
+ // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+ // size0, size1, ..., sizeN-1]
+ SmallVector<Type> resultTypes;
+ resultTypes.push_back(MemRefType::get(
+ {}, memrefType.getElementType())); // base memref (unranked)
+ resultTypes.push_back(indexType); // offset
+
+ for (unsigned i = 0; i < rank; ++i)
+ resultTypes.push_back(indexType); // strides
+
+ for (unsigned i = 0; i < rank; ++i)
+ resultTypes.push_back(indexType); // sizes
+
+ auto meta = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, resultTypes, baseMemref);
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+ }
+ // Adjust strides according to the permutation map (e.g., for transpose)
+ adjustStridesForPermutation(permMap, strides);
+ return strides;
+}
+
+// This function compute the vectors of localOffsets for scattered load/stores.
+// It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter Example:
+// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+// %cst {in_bounds = [true, true, true, true]}>} :
+// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+// %6 = vector.step: vector<4xindex>
+// %7 = vector.step: vector<2xindex>
+// %8 = vector.step: vector<6xindex>
+// %9 = vector.step: vector<32xindex>
+// %10 = arith.mul %6, 384
+// %11 = arith.mul %7, 192
+// %12 = arith.mul %8, 32
+// %13 = arith.mul %9, 1
+// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
+// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
+// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
+// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
+// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
+// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
+// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
+// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
+// %22 = arith.add %18, %19
+// %23 = arith.add %20, %21
+// %local_offsets = arith.add %22, %23
+// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+// %offsets = orig_offset + local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Value> strides) {
+ Location loc = xferOp.getLoc();
+ VectorType vectorType = xferOp.getVectorType();
+ SmallVector<Value> indices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ // Create vector.step operations for each dimension
+ SmallVector<Value> stepVectors;
+ llvm::map_to_vector(vectorShape, [&](int64_t dim) {
+ auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+ auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
+ stepVectors.push_back(stepOp);
+ return stepOp;
+ });
+
+ // Multiply step vectors by corresponding strides
+ size_t memrefRank = strides.size();
+ size_t vectorRank = vectorShape.size();
+ SmallVector<Value> strideMultiplied;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ size_t memrefDim = memrefRank - vectorRank + i;
+ Value strideValue = strides[memrefDim];
+ auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
+ auto bcastOp =
+ vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+ auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+ strideMultiplied.push_back(mulOp);
+ }
+
+ // Shape cast each multiplied vector to add singleton dimensions
+ SmallVector<Value> shapeCasted;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ SmallVector<int64_t> newShape(vectorRank, 1);
+ newShape[i] = vectorShape[i];
+ auto newType = VectorType::get(newShape, rewriter.getIndexType());
+ auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+ strideMultiplied[i]);
+ shapeCasted.push_back(castOp);
+ }
+
+ // Broadcast each shape-casted vector to full vector shape
+ SmallVector<Value> broadcasted;
+ auto fullIndexVectorType =
+ VectorType::get(vectorShape, rewriter.getIndexType());
+ for (Value shapeCastVal : shapeCasted) {
+ auto broadcastOp = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, shapeCastVal);
+ broadcasted.push_back(broadcastOp);
+ }
+
+ // Add all broadcasted vectors together to compute local offsets
+ Value localOffsets = broadcasted[0];
+ for (size_t i = 1; i < broadcasted.size(); ++i)
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
+
+ // Compute base offset from transfer read indices
+ Value baseOffset = nullptr;
+ if (!indices.empty()) {
+ baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ for (size_t i = 0; i < indices.size(); ++i) {
+ Value strideVal = strides[i];
+ Value offsetContrib =
+ arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+ baseOffset =
+ arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+ }
+ // Broadcast base offset to match vector shape
+ Value bcastBase = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, baseOffset);
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+ }
+ return localOffsets;
+}
+
+// Collapse memref shape to 1D
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ Location loc = xferOp.getLoc();
+
+ Value baseMemref = xferOp.getBase();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+ Type elementType = memrefType.getElementType();
+
+ // Compute the total number of elements in the memref
+ MemRefType flatMemrefType;
+ if (memrefType.hasStaticShape()) {
+ auto totalElements = memrefType.getNumElements();
+ flatMemrefType = MemRefType::get({totalElements}, elementType);
+ } else {
+ flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+ }
+
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices allDims =
+ llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
+ reassociation.push_back(allDims);
+
+ auto collapseOp = memref::CollapseShapeOp::create(
+ rewriter, loc, flatMemrefType, baseMemref, reassociation);
+ return collapseOp;
+}
+
+static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) {
+
+ Location loc = readOp.getLoc();
+ VectorType vectorType = readOp.getVectorType();
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+ auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+ SmallVector<Value> strides = computeStrides(readOp, rewriter);
+ if (strides.empty())
+ return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+ Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+ Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
+ Value mask = vector::ConstantMaskOp::create(
+ rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+ vectorShape);
+ auto gatherOp = xegpu::LoadGatherOp::create(
+ rewriter, loc, vectorType, flatMemref, localOffsets, mask,
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+ rewriter.replaceOp(readOp, gatherOp.getResult());
+ return success();
+}
+
+static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) {
+
+ Location loc = writeOp.getLoc();
+ VectorType vectorType = writeOp.getVectorType();
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+ SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+
+ Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+
+ Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+
+ Value mask = vector::ConstantMaskOp::create(
+ rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+ vectorShape);
+ xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+ localOffsets, mask,
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+ rewriter.eraseOp(writeOp);
+ return success();
+}
+
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
if (failed(transferPreconditions(rewriter, readOp)))
return failure();
+ // TODO:This check needs to be replaced with proper uArch capability check
+ auto chip = xegpu::getChipStr(readOp);
+ if (chip != "pvc" && chip != "bmg") {
+ // lower to scattered load Op if the target HW doesn't have 2d block load
+ // support
+ // TODO: add support for OutOfBound access
+ if (readOp.hasOutOfBoundsDim())
+ return failure();
+ return lowerToScatteredLoadOp(readOp, rewriter);
+ }
+
+ // Perform common data transfer checks.
+ VectorType vecTy = readOp.getVectorType();
+ if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+ return failure();
+
bool isOutOfBounds = readOp.hasOutOfBoundsDim();
if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
return rewriter.notifyMatchFailure(
@@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
AffineMap readMap = readOp.getPermutationMap();
bool isTransposeLoad = !readMap.isMinorIdentity();
- VectorType vecTy = readOp.getVectorType();
Type elementType = vecTy.getElementType();
unsigned minTransposeBitWidth = 32;
if (isTransposeLoad &&
@@ -221,11 +505,26 @@ struct TransferWriteLowering
if (failed(transferPreconditions(rewriter, writeOp)))
return failure();
+ // TODO:This check needs to be replaced with proper uArch capability check
+ auto chip = xegpu::getChipStr(writeOp);
+ if (chip != "pvc" && chip != "bmg") {
+ // lower to scattered store Op if the target HW doesn't have 2d block
+ // store support
+ // TODO: add support for OutOfBound access
+ if (writeOp.hasOutOfBoundsDim())
+ return failure();
+ return lowerToScatteredStoreOp(writeOp, rewriter);
+ }
+
+ // Perform common data transfer checks.
+ VectorType vecTy = writeOp.getVectorType();
+ if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+ return failure();
+
AffineMap map = writeOp.getPermutationMap();
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
- VectorType vecTy = writeOp.getVectorType();
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..84b2580
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -0,0 +1,27 @@
+add_mlir_conversion_library(MLIRXeGPUToXeVM
+ XeGPUToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRGPUDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRVectorDialect
+ MLIRArithDialect
+ MLIRIndexDialect
+ MLIRSCFDialect
+ MLIRXeGPUDialect
+ MLIRPass
+ MLIRTransforms
+ MLIRSCFTransforms
+)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
new file mode 100644
index 0000000..a7f2dc2
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -0,0 +1,1026 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+// TODO: Below are uArch dependent values, should move away from hardcoding
+static constexpr int32_t systolicDepth{8};
+static constexpr int32_t executionSize{16};
+
+// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ TensorOffsetW = 4, // Tensor offset W (i32)
+ TensorOffsetH = 5 // Tensor offset H (i32)
+};
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+ switch (xeGpuMemspace) {
+ case xegpu::MemorySpace::Global:
+ return static_cast<int>(xevm::AddrSpace::GLOBAL);
+ case xegpu::MemorySpace::SLM:
+ return static_cast<int>(xevm::AddrSpace::SHARED);
+ }
+}
+
+// Get same bitwidth flat vector type of new element type.
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+ Type toElemType) {
+ auto elemType = currentVecType.getElementType();
+ auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+ auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+ const int size =
+ currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+ return VectorType::get(size, toElemType);
+}
+
+static xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::CACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::READ_INVALIDATE:
+ return xevm::LoadCacheControl::INVALIDATE_READ;
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+static xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_BACK:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_THROUGH:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+class CreateNdDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op,
+ xegpu::CreateNdDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto source = op.getSource();
+ // Op is lowered to a code sequence that populates payload.
+ // Payload is a 8xi32 vector. Offset to individual fields are defined in
+ // NdTdescOffset enum.
+ Type payloadElemTy = rewriter.getI32Type();
+ VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ Type i64Ty = rewriter.getI64Type();
+ // 4xi64 view is used for inserting the base pointer.
+ VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+ // Initialize payload to zero.
+ Value payload = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+ Value baseAddr;
+ Value baseShapeW;
+ Value baseShapeH;
+ Value offsetW;
+ Value offsetH;
+
+ // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ // Descriptor shape is expected to be 2D.
+ int64_t rank = mixedSizes.size();
+ if (rank != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+ auto sourceTy = source.getType();
+ auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+ // If source is a memref, we need to extract the aligned pointer as index.
+ // Pointer type is passed as i32 or i64 by type converter.
+ if (sourceMemrefTy) {
+ if (!sourceMemrefTy.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+ }
+ baseAddr =
+ memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ } else {
+ baseAddr = adaptor.getSource();
+ }
+ // Utility for creating offset values from op fold result.
+ auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
+ unsigned idx) -> Value {
+ Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+ val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
+ return val;
+ };
+ // Offsets can be either 2D or not provided (0 is used).
+ if (mixedOffsets.size() == 2) {
+ offsetW = createOffset(mixedOffsets, 1);
+ offsetH = createOffset(mixedOffsets, 0);
+ } else if (mixedOffsets.size() == 0) {
+ offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ } else {
+ return rewriter.notifyMatchFailure(op,
+ "Expected 2D offsets or no offsets.");
+ }
+ // Get shape values from op fold results.
+ baseShapeW = createOffset(mixedSizes, 1);
+ baseShapeH = createOffset(mixedSizes, 0);
+ if (sourceMemrefTy) {
+ // Cast index to i64.
+ baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ } else if (baseAddr.getType() != i64Ty) {
+ // Pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ }
+ // Populate payload.
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+ payLoadAsI64 =
+ vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+ static_cast<int>(NdTdescOffset::BasePtr));
+ payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+ static_cast<int>(NdTdescOffset::BaseShapeW));
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+ static_cast<int>(NdTdescOffset::BaseShapeH));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetW, payload,
+ static_cast<int>(NdTdescOffset::TensorOffsetW));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetH, payload,
+ static_cast<int>(NdTdescOffset::TensorOffsetH));
+ rewriter.replaceOp(op, payload);
+ return success();
+ }
+};
+
+class UpdateNdOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+ xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto mixedOffsets = op.getMixedOffsets();
+ // Only 2D offsets are supported for now.
+ if (mixedOffsets.size() != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+ auto payload = adaptor.getTensorDesc();
+ // Utility for updating payload offset values from op fold result.
+ auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offset);
+ Value oldOffset =
+ vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
+ Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+ return vector::InsertOp::create(rewriter, loc, newOffset, payload,
+ payloadPos);
+ };
+ // Update offsets in the payload.
+ payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ rewriter.replaceOp(op, payload);
+ return success();
+ }
+};
+
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+
+ auto tdesc = adaptor.getTensorDesc();
+ auto tdescTy = op.getTensorDescType();
+ if (tdescTy.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto elemType = tdescTy.getElementType();
+ auto elemBitSize = elemType.getIntOrFloatBitWidth();
+ if (elemBitSize % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr = vector::ExtractOp::create(
+ rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ // Offsets provided in two ways:
+ // 1. Offsets are extracted from the tensor descriptor.
+ // 2. (Mixed) offsets which are provided by the op.
+ Value offsetW;
+ Value offsetH;
+ auto mixedOffsets = op.getMixedOffsets();
+ int64_t opOffsetsSize = mixedOffsets.size();
+ if (opOffsetsSize != 0 && opOffsetsSize != 2)
+ return rewriter.notifyMatchFailure(op,
+ "Expected 2D offsets or no offsets.");
+ if (opOffsetsSize) {
+ // If mixed offsets are provided by the op convert them to i32.
+ offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
+ } else {
+ // If offsets are not available, we need to extract them from the tensor
+ // descriptor.
+ offsetW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ offsetH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ }
+ // Get address space from tensor descriptor memory space.
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // Compute element byte size and surface width in bytes.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ Value surfaceW =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+ // Get tile sizes and vblocks from the tensor descriptor type.
+ auto tileW = tdescTy.getDimSize(1);
+ auto tileH = tdescTy.getDimSize(0);
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+ offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+ offsetH, elemBitSize, tileW, tileH, vblocks,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+ surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ return success();
+ }
+};
+
+// Add a builder that creates
+// offset * elemByteSize + baseAddr
+static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
+ Value baseAddr, Value offset, int64_t elemByteSize) {
+ Value byteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
+ Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
+ return newAddr;
+}
+
+class CreateDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto eTy = op.getTensorDescType().getElementType();
+ auto eBw = eTy.getIntOrFloatBitWidth();
+ if (eBw % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ auto loc = op.getLoc();
+ // Offsets are provided as scalar i64 by type converter.
+ auto offsets = adaptor.getOffsets();
+ // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
+ // But type converter will convert them to integer types.
+ Value addr = adaptor.getSource();
+ // ui32 or i32 are passed as i32 so they need to be casted to i64.
+ if (addr.getType() != rewriter.getI64Type())
+ addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
+ auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
+ rewriter.replaceOp(op, laneAddr);
+ return success();
+ }
+};
+
+class UpdateOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateOffsetOp op,
+ xegpu::UpdateOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto eTy = op.getTensorDescType().getElementType();
+ auto eBw = eTy.getIntOrFloatBitWidth();
+ if (eBw % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ auto loc = op.getLoc();
+ // Scatter descriptor is provided as scalar i64 by type converter.
+ // Offsets are provided as scalar i64 by type converter.
+ Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+ adaptor.getOffsets(), eBw / 8);
+ rewriter.replaceOp(op, newOffset);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdescTy = op.getTensorDescType();
+ Value basePtrI64;
+ // Load result or Store valye Type can be vector or scalar.
+ Type valOrResTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
+ valOrResTy = op.getResult().getType();
+ else
+ valOrResTy = adaptor.getValue().getType();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+ bool hasScalarVal = !valOrResVecTy;
+ int64_t elemBitWidth =
+ hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+ : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+ // Default memory space is global.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ // If tensor descriptor is available, we use its memory space.
+ if (tdescTy)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // Base pointer can come from source (load) or dest (store).
+ // If they are memrefs, we use their memory space.
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ basePtrI64 = adaptor.getSource();
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ } else {
+ basePtrI64 = adaptor.getDest();
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ }
+ // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
+ if (basePtrI64.getType() != rewriter.getI64Type()) {
+ basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
+ }
+ Value offsets = adaptor.getOffsets();
+ Value mask = adaptor.getMask();
+ if (offsets) {
+ if (dyn_cast<VectorType>(offsets.getType())) {
+ // Offset needs be scalar. Single element vector is converted to scalar
+ // by type converter.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ // If offsets are provided, we add them to the base pointer.
+ // Offsets are in number of elements, we need to multiply by
+ // element byte size.
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ }
+ }
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+ Value maskForLane;
+ VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
+ if (maskVecTy) {
+ // Mask needs be scalar. Single element vector is converted to scalar by
+ // type converter.
+ return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
+ } else
+ maskForLane = mask;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
+ maskForLane, true, true);
+ // If mask is true,- then clause - load from memory and yield.
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ if (!hasScalarVal)
+ valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
+ valOrResVecTy.getElementType());
+ Value loaded =
+ LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
+ // Set cache control attribute on the load operation.
+ loaded.getDefiningOp()->setAttr(
+ "cache_control", xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
+ scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // If mask is false - else clause -yield a vector of zeros.
+ auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
+ TypedAttr eVal;
+ if (eTy.isFloat())
+ eVal = FloatAttr::get(eTy, 0.0);
+ else
+ eVal = IntegerAttr::get(eTy, 0);
+ if (hasScalarVal)
+ loaded = arith::ConstantOp::create(rewriter, loc, eVal);
+ else
+ loaded = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
+ scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+ rewriter.replaceOp(op, ifOp.getResult(0));
+ } else {
+ // If mask is true, perform the store.
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
+ auto body = ifOp.getBody();
+ rewriter.setInsertionPointToStart(body);
+ auto storeOp =
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
+ // Set cache control attribute on the store operation.
+ storeOp.getOperation()->setAttr(
+ "cache_control", xevm::StoreCacheControlAttr::get(
+ ctxt, translateStoreXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
+class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdescTy = op.getTensorDescType();
+ Value basePtrI64 = adaptor.getSource();
+ // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
+ if (basePtrI64.getType() != rewriter.getI64Type())
+ basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
+ Value offsets = adaptor.getOffsets();
+ if (offsets) {
+ VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
+ if (offsetsVecTy) {
+ // Offset needs be scalar.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ int64_t elemBitWidth{0};
+ int64_t elemByteSize;
+ // Element byte size can come from three sources:
+ if (tdescTy) {
+ // If tensor descriptor is available, we use its element type to
+ // determine element byte size.
+ elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+ } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
+ // If memref is available, we use its element type to
+ // determine element byte size.
+ elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
+ } else {
+ // Otherwise, we use the provided offset byte alignment.
+ elemByteSize = *op.getOffsetAlignByte();
+ }
+ if (elemBitWidth != 0) {
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ elemByteSize = elemBitWidth / 8;
+ }
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ }
+ }
+ // Default memory space is global.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ // If tensor descriptor is available, we use its memory space.
+ if (tdescTy)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // If source is a memref, we use its memory space.
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value ptrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ // Create the prefetch op with cache control attribute.
+ xevm::PrefetchOp::create(
+ rewriter, loc, ptrLLVM,
+ xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
+ switch (op.getFenceScope()) {
+ case xegpu::FenceScope::Workgroup:
+ memScope = xevm::MemScope::WORKGROUP;
+ break;
+ case xegpu::FenceScope::GPU:
+ memScope = xevm::MemScope::DEVICE;
+ break;
+ }
+ xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
+ switch (op.getMemoryKind()) {
+ case xegpu::MemorySpace::Global:
+ addrSpace = xevm::AddrSpace::GLOBAL;
+ break;
+ case xegpu::MemorySpace::SLM:
+ addrSpace = xevm::AddrSpace::SHARED;
+ break;
+ }
+ xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto aTy = cast<VectorType>(op.getLhs().getType());
+ auto bTy = cast<VectorType>(op.getRhs().getType());
+ auto resultType = cast<VectorType>(op.getResultType());
+
+ auto encodePrecision = [&](Type type) -> xevm::ElemType {
+ if (type == rewriter.getBF16Type())
+ return xevm::ElemType::BF16;
+ else if (type == rewriter.getF16Type())
+ return xevm::ElemType::F16;
+ else if (type == rewriter.getTF32Type())
+ return xevm::ElemType::TF32;
+ else if (type.isInteger(8)) {
+ if (type.isUnsignedInteger())
+ return xevm::ElemType::U8;
+ return xevm::ElemType::S8;
+ } else if (type == rewriter.getF32Type())
+ return xevm::ElemType::F32;
+ else if (type.isInteger(32))
+ return xevm::ElemType::S32;
+ llvm_unreachable("add more support for ElemType");
+ };
+ xevm::ElemType precATy = encodePrecision(aTy.getElementType());
+ xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
+ Value c = op.getAcc();
+ if (!c) {
+ auto elementTy = resultType.getElementType();
+ Attribute initValueAttr;
+ if (isa<FloatType>(elementTy))
+ initValueAttr = FloatAttr::get(elementTy, 0.0);
+ else
+ initValueAttr = IntegerAttr::get(elementTy, 0);
+ c = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
+ }
+
+ Value aVec = op.getLhs();
+ Value bVec = op.getRhs();
+ auto cvecty = cast<VectorType>(c.getType());
+ xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
+ xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
+ VectorType cNty =
+ VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
+ if (cvecty != cNty)
+ c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
+ Value dpasRes = xevm::MMAOp::create(
+ rewriter, loc, cNty, aVec, bVec, c,
+ xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
+ systolicDepth *
+ getNumOperandsPerDword(precATy)),
+ xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
+ if (cvecty != cNty)
+ dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
+ rewriter.replaceOp(op, dpasRes);
+ return success();
+ }
+
+private:
+ static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+ switch (pTy) {
+ case xevm::ElemType::TF32:
+ return 1;
+ case xevm::ElemType::BF16:
+ case xevm::ElemType::F16:
+ return 2;
+ case xevm::ElemType::U8:
+ case xevm::ElemType::S8:
+ return 4;
+ default:
+ llvm_unreachable("unsupported xevm::ElemType");
+ }
+ }
+};
+
+static std::optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
+ switch (arithKind) {
+ case arith::AtomicRMWKind::addf:
+ return LLVM::AtomicBinOp::fadd;
+ case arith::AtomicRMWKind::addi:
+ return LLVM::AtomicBinOp::add;
+ case arith::AtomicRMWKind::assign:
+ return LLVM::AtomicBinOp::xchg;
+ case arith::AtomicRMWKind::maximumf:
+ return LLVM::AtomicBinOp::fmax;
+ case arith::AtomicRMWKind::maxs:
+ return LLVM::AtomicBinOp::max;
+ case arith::AtomicRMWKind::maxu:
+ return LLVM::AtomicBinOp::umax;
+ case arith::AtomicRMWKind::minimumf:
+ return LLVM::AtomicBinOp::fmin;
+ case arith::AtomicRMWKind::mins:
+ return LLVM::AtomicBinOp::min;
+ case arith::AtomicRMWKind::minu:
+ return LLVM::AtomicBinOp::umin;
+ case arith::AtomicRMWKind::ori:
+ return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::andi:
+ return LLVM::AtomicBinOp::_and;
+ default:
+ return std::nullopt;
+ }
+}
+
+class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdesc = op.getTensorDesc().getType();
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
+ Value basePtrI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
+ VectorType srcOrDstFlatVecTy = VectorType::get(
+ srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+ Value srcFlatVec = vector::ShapeCastOp::create(
+ rewriter, loc, srcOrDstFlatVecTy, op.getValue());
+ auto atomicKind = matchSimpleAtomicOp(op.getKind());
+ assert(atomicKind.has_value());
+ Value resVec = srcFlatVec;
+ for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
+ auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
+ Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(i));
+ Value currPtr =
+ LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
+ srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
+ Value newVal =
+ LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
+ val, LLVM::AtomicOrdering::seq_cst);
+ resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
+ }
+ rewriter.replaceOp(op, resVec);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct ConvertXeGPUToXeVMPass
+ : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
+ using Base::Base;
+
+ void runOnOperation() override {
+ LLVMTypeConverter typeConverter(&getContext());
+ typeConverter.addConversion([&](VectorType type) -> Type {
+ unsigned rank = type.getRank();
+ auto elemType = type.getElementType();
+ // If the element type is index, convert it to i64.
+ if (llvm::isa<IndexType>(elemType))
+ elemType = IntegerType::get(&getContext(), 64);
+ // If the vector is a scalar or has a single element, return the element
+ if (rank < 1 || type.getNumElements() == 1)
+ return elemType;
+ // Otherwise, convert the vector to a flat vector type.
+ int64_t sum =
+ std::accumulate(type.getShape().begin(), type.getShape().end(),
+ int64_t{1}, std::multiplies<int64_t>());
+ return VectorType::get(sum, elemType);
+ });
+ typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ if (type.isScattered())
+ return IntegerType::get(&getContext(), 64);
+ auto i32Type = IntegerType::get(&getContext(), 32);
+ return VectorType::get(8, i32Type);
+ });
+ typeConverter.addConversion([&](MemRefType type) -> Type {
+ // Convert MemRefType to i64 type.
+ return IntegerType::get(&getContext(), 64);
+ });
+
+ // LLVM type converter puts unrealized casts for the following cases:
+ // add materialization casts to handle them.
+
+ // Materialization to convert memref to i64
+ auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+
+ Value addr =
+ memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
+ return arith::IndexCastUIOp::create(builder, loc, type, addr)
+ .getResult();
+ }
+ return {};
+ };
+
+ // Materialization to convert ui64 to i64
+ auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType() == builder.getIntegerType(64, false)) {
+ Value cast =
+ index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ }
+ return {};
+ };
+
+ // Materialization to convert ui32 to i32
+ auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType() == builder.getIntegerType(32, false)) {
+ Value cast =
+ index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ }
+ return {};
+ };
+
+ // Materialization to convert
+ // - single element 1D vector to scalar
+ // - bitcast vector of same rank
+ // - shape vector of different rank but same element type
+ auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+ if (vecTy.getNumElements() == 1) {
+ // If the vector has a single element, return the element type.
+ Value cast =
+ vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ if (vecTy.getElementType() == builder.getIndexType())
+ cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ return cast;
+ } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+ // If the target type is a vector of same rank,
+ // bitcast to the target type.
+ if (targetVecTy.getRank() == vecTy.getRank())
+ return vector::BitCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ else if (targetVecTy.getElementType() == vecTy.getElementType()) {
+ // If the target type is a vector of different rank but same element
+ // type, reshape to the target type.
+ return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(memrefMaterializationCast);
+ typeConverter.addSourceMaterialization(ui64MaterializationCast);
+ typeConverter.addSourceMaterialization(ui32MaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
+ typeConverter.addTargetMaterialization(memrefMaterializationCast);
+ typeConverter.addTargetMaterialization(ui32MaterializationCast);
+ typeConverter.addTargetMaterialization(ui64MaterializationCast);
+ typeConverter.addTargetMaterialization(vectorMaterializationCast);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
+ vector::VectorDialect, arith::ArithDialect,
+ memref::MemRefDialect, gpu::GPUDialect,
+ index::IndexDialect>();
+ target.addIllegalDialect<xegpu::XeGPUDialect>();
+
+ RewritePatternSet patterns(&getContext());
+ populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
+ scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+void mlir::populateXeGPUToXeVMConversionPatterns(
+ const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
+ typeConverter, patterns.getContext());
+ patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
+ AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+ LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
+ LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
+ typeConverter, patterns.getContext());
+ patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
+ patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 9a0a230..11a40d6 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -511,6 +511,18 @@ LogicalResult DPPOp::verify() {
}
//===----------------------------------------------------------------------===//
+// PermlaneSwapOp
+//===----------------------------------------------------------------------===//
+LogicalResult PermlaneSwapOp::verify() {
+ unsigned rowLength = getRowLength();
+
+ if (rowLength != 16 && rowLength != 32)
+ return emitOpError("row_length attribute must either be 16 or 32.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
@@ -518,8 +530,8 @@ LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
- if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
- return emitOpError("destination types must be contiguous");
+ if (!dstType.areTrailingDimsContiguous(1))
+ return emitOpError("destination type inner most dim must be contiguous");
auto elemType = srcType.getElementType();
// Check $src and $dst element types are the same.
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 729e3da..d35853b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
ResolveStridedMetadata.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
DEPENDS
MLIRAMDGPUTransformsIncGen
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7e..d547510 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final
}
};
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+ Value view, mlir::OperandRange indices,
+ SmallVectorImpl<Value> &resolvedIndices,
+ Value &memrefBase, StringRef role) {
+ Operation *defOp = view.getDefiningOp();
+ if (!defOp) {
+ return failure();
+ }
+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+ .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, subviewOp.getMixedOffsets(),
+ subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+ resolvedIndices);
+ memrefBase = subviewOp.getSource();
+ return success();
+ })
+ .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, indices, resolvedIndices,
+ false))) {
+ return failure();
+ }
+ memrefBase = expandShapeOp.getViewSource();
+ return success();
+ })
+ .Case<memref::CollapseShapeOp>(
+ [&](memref::CollapseShapeOp collapseShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+ loc, rewriter, collapseShapeOp, indices,
+ resolvedIndices))) {
+ return failure();
+ }
+ memrefBase = collapseShapeOp.getViewSource();
+ return success();
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(
+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+ "CollapseShapeOp")
+ .str());
+ });
+}
+
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value memrefSource;
- SmallVector<Value> sourceIndices;
- auto foldResult =
- llvm::TypeSwitch<Operation *, LogicalResult>(
- op.getSrc().getDefiningOp())
- .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
- // If the source is a SubViewOp, we can directly rewrite the
- // GatherToLDSOp.
- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, loc, subviewOp.getMixedOffsets(),
- subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
- op.getSrcIndices(), sourceIndices);
- memrefSource = subviewOp.getSource();
- return success();
- })
- .Case<memref::ExpandShapeOp>(
- [&](memref::ExpandShapeOp expandShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesExpandShape(
- loc, rewriter, expandShapeOp, op.getSrcIndices(),
- sourceIndices, false))) {
- return failure();
- }
- memrefSource = expandShapeOp.getViewSource();
- return success();
- })
- .Case<memref::CollapseShapeOp>(
- [&](memref::CollapseShapeOp collapseShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
- loc, rewriter, collapseShapeOp, op.getSrcIndices(),
- sourceIndices))) {
- return failure();
- }
- memrefSource = collapseShapeOp.getViewSource();
- return success();
- })
- .Default([&](Operation *op) {
- // If the source is not a SubViewOp, ExpandShapeOp, or
- // CollapseShapeOp, we cannot fold the GatherToLDSOp.
- return rewriter.notifyMatchFailure(
- op,
- "source producer is not one of SubViewOp, ExpandShapeOp, or "
- "CollapseShapeOp");
- });
+ SmallVector<Value> sourceIndices, destIndices;
+ Value memrefSource, memrefDest;
+
+ auto foldSrcResult =
+ foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+ sourceIndices, memrefSource, "source");
+
+ if (failed(foldSrcResult)) {
+ memrefSource = op.getSrc();
+ sourceIndices = op.getSrcIndices();
+ }
+
+ auto foldDstResult =
+ foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+ destIndices, memrefDest, "destination");
- if (failed(foldResult)) {
- return failure();
+ if (failed(foldDstResult)) {
+ memrefDest = op.getDst();
+ destIndices = op.getDstIndices();
}
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
- op.getDst(), op.getDstIndices(),
+ memrefDest, destIndices,
op.getTransferType());
return success();
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110c..68990ef 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 86edc2b..b405ec2 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
int64_t lb = forOp.getConstantLowerBound();
dividend[pos] = 1;
dividend.back() -= lb;
- addLocalFloorDiv(dividend, step);
+ unsigned qPos = addLocalFloorDiv(dividend, step);
// Second constraint: (iv - lb) - step * q = 0.
SmallVector<int64_t, 8> eq(getNumCols(), 0);
eq[pos] = 1;
eq.back() -= lb;
// For the local var just added above.
- eq[getNumCols() - 2] = -step;
+ eq[qPos] = -step;
addEquality(eq);
}
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 2f85e0b..166d39e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -21,6 +21,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <numeric>
#include <optional>
@@ -548,19 +549,19 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
// Check whether there is any negative direction vector in the
// dependence components found above, which means that dependence is
// violated by the default hyper-rect tiling method.
- LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
- "for dependence at depth: "
- << Twine(d) << " between:\n";);
- LLVM_DEBUG(srcAccess.opInst->dump());
- LLVM_DEBUG(dstAccess.opInst->dump());
+ LDBG() << "Checking whether tiling legality violated "
+ << "for dependence at depth: " << Twine(d) << " between:"
+ << OpWithFlags(srcAccess.opInst, OpPrintingFlags().skipRegions())
+ << "\nand:\n"
+ << OpWithFlags(dstAccess.opInst,
+ OpPrintingFlags().skipRegions());
for (const DependenceComponent &depComp : depComps) {
if (depComp.lb.has_value() && depComp.ub.has_value() &&
*depComp.lb < *depComp.ub && *depComp.ub < 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Dependence component lb = " << Twine(*depComp.lb)
- << " ub = " << Twine(*depComp.ub)
- << " is negative at depth: " << Twine(d)
- << " and thus violates the legality rule.\n");
+ LDBG() << "Dependence component lb = " << Twine(*depComp.lb)
+ << " ub = " << Twine(*depComp.ub)
+ << " is negative at depth: " << Twine(d)
+ << " and thus violates the legality rule.";
return false;
}
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index a89c1ae..99ea20b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -241,7 +242,7 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
}
bool MemRefDependenceGraph::init() {
- LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+ LDBG() << "--- Initializing MDG ---";
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
@@ -288,8 +289,7 @@ bool MemRefDependenceGraph::init() {
// Return false if non-handled/unknown region-holding ops are found. We
// won't know what such ops do or what its regions mean; for e.g., it may
// not be an imperative op.
- LLVM_DEBUG(llvm::dbgs()
- << "MDG init failed; unknown region-holding op found!\n");
+ LDBG() << "MDG init failed; unknown region-holding op found!";
return false;
}
// We aren't creating nodes for memory-effect free ops either with no
@@ -297,7 +297,7 @@ bool MemRefDependenceGraph::init() {
// interface.
}
- LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");
+ LDBG() << "Created " << nodes.size() << " nodes";
// Add dependence edges between nodes which produce SSA values and their
// users. Load ops can be considered as the ones producing SSA values.
@@ -556,9 +556,8 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
gatherDefiningNodes(dstId, definingNodes);
if (llvm::any_of(definingNodes,
[&](unsigned id) { return hasDependencePath(srcId, id); })) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: a defining op with a user in the dst "
- "loop has dependence from the src loop\n");
+ LDBG() << "Can't fuse: a defining op with a user in the dst "
+ << "loop has dependence from the src loop";
return nullptr;
}
@@ -957,20 +956,20 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
FlatAffineValueConstraints srcConstraints;
// TODO: Store the source's domain to avoid computation at each depth.
if (failed(getSourceAsConstraints(srcConstraints))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
+ LDBG() << "Unable to compute source's domain";
return std::nullopt;
}
// As the set difference utility currently cannot handle symbols in its
// operands, validity of the slice cannot be determined.
if (srcConstraints.getNumSymbolVars() > 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
+ LDBG() << "Cannot handle symbols in source domain";
return std::nullopt;
}
// TODO: Handle local vars in the source domains while using the 'projectOut'
// utility below. Currently, aligning is not done assuming that there will be
// no local vars in the source domain.
if (srcConstraints.getNumLocalVars() != 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
+ LDBG() << "Cannot handle locals in source domain";
return std::nullopt;
}
@@ -978,7 +977,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
// fusion succeeds.
FlatAffineValueConstraints sliceConstraints;
if (failed(getAsConstraints(&sliceConstraints))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
+ LDBG() << "Unable to compute slice's domain";
return std::nullopt;
}
@@ -987,11 +986,11 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
sliceConstraints.projectOut(ivs.size(),
sliceConstraints.getNumVars() - ivs.size());
- LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
- LLVM_DEBUG(srcConstraints.dump());
- LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
- "(expressed in terms of its source's IVs):\n");
- LLVM_DEBUG(sliceConstraints.dump());
+ LDBG() << "Domain of the source of the slice:\n"
+ << "Source constraints:" << srcConstraints
+ << "\nDomain of the slice if this fusion succeeds "
+ << "(expressed in terms of its source's IVs):\n"
+ << "Slice constraints:" << sliceConstraints;
// TODO: Store 'srcSet' to avoid recalculating for each depth.
PresburgerSet srcSet(srcConstraints);
@@ -999,7 +998,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
PresburgerSet diffSet = sliceSet.subtract(srcSet);
if (!diffSet.isIntegerEmpty()) {
- LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
+ LDBG() << "Incorrect slice";
return false;
}
return true;
@@ -1172,8 +1171,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
unsigned rank = access.getRank();
- LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
- << "\ndepth: " << loopDepth << "\n";);
+ LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;
// 0-d memrefs.
if (rank == 0) {
@@ -1236,7 +1234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
if (auto constVal = getConstantIntValue(symbol))
cst.addBound(BoundType::EQ, symbol, constVal.value());
} else {
- LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
+ LDBG() << "unknown affine dimensional value";
return failure();
}
}
@@ -1260,7 +1258,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
// Add access function equalities to connect loop IVs to data dimensions.
if (failed(cst.composeMap(&accessValueMap))) {
op->emitError("getMemRefRegion: compose affine map failed");
- LLVM_DEBUG(accessValueMap.getAffineMap().dump());
+ LDBG() << "Access map: " << accessValueMap.getAffineMap();
return failure();
}
@@ -1317,8 +1315,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
}
cst.removeTrivialRedundancy();
- LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
- LLVM_DEBUG(cst.dump());
+ LDBG() << "Memory region: " << cst;
return success();
}
@@ -1346,14 +1343,14 @@ std::optional<int64_t> MemRefRegion::getRegionSize() {
auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ LDBG() << "Non-identity layout map not yet supported";
return false;
}
// Compute the extents of the buffer.
std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
if (!numElements) {
- LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+ LDBG() << "Dynamic shapes not yet supported";
return std::nullopt;
}
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
@@ -1397,8 +1394,7 @@ LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
/*addMemRefDimBounds=*/false)))
return success();
- LLVM_DEBUG(llvm::dbgs() << "Memory region");
- LLVM_DEBUG(region.getConstraints()->dump());
+ LDBG() << "Memory region: " << region.getConstraints();
bool outOfBounds = false;
unsigned rank = loadOrStoreOp.getMemRefType().getRank();
@@ -1558,7 +1554,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Check if 'loopDepth' exceeds nesting depth of src/dst ops.
if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
(isBackwardSlice && loopDepth > getNestingDepth(b))) {
- LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
+ LDBG() << "Invalid loop depth";
return SliceComputationResult::GenericFailure;
}
@@ -1571,7 +1567,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
&dependenceConstraints, /*dependenceComponents=*/nullptr,
/*allowRAR=*/readReadAccesses);
if (result.value == DependenceResult::Failure) {
- LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
+ LDBG() << "Dependence check failed";
return SliceComputationResult::GenericFailure;
}
if (result.value == DependenceResult::NoDependence)
@@ -1586,8 +1582,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
// Initialize 'sliceUnionCst' with the bounds computed in previous step.
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n");
+ LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
@@ -1597,8 +1592,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
FlatAffineValueConstraints tmpSliceCst;
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n");
+ LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
@@ -1630,8 +1624,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumLocalVars() > 0 ||
tmpSliceCst.getNumLocalVars() > 0 ||
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute union bounding box of slice bounds\n");
+ LDBG() << "Unable to compute union bounding box of slice bounds";
return SliceComputationResult::GenericFailure;
}
}
@@ -1639,7 +1632,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Empty union.
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
- LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n");
+ LDBG() << "empty slice union - unexpected";
return SliceComputationResult::GenericFailure;
}
@@ -1652,7 +1645,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
unsigned innermostCommonLoopDepth =
getInnermostCommonLoopDepth(ops, &surroundingLoops);
if (loopDepth > innermostCommonLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
+ LDBG() << "Exceeds max loop depth";
return SliceComputationResult::GenericFailure;
}
@@ -1696,7 +1689,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// that the slice is valid, otherwise return appropriate failure status.
std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
if (!isSliceValid) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
+ LDBG() << "Cannot determine if the slice is valid";
return SliceComputationResult::GenericFailure;
}
if (!*isSliceValid)
@@ -2050,7 +2043,8 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (failed(
region->compute(opInst,
/*loopDepth=*/getNestingDepth(&*block.begin())))) {
- LLVM_DEBUG(opInst->emitError("error obtaining memory region"));
+ LDBG() << "Error obtaining memory region";
+ opInst->emitError("error obtaining memory region");
return failure();
}
@@ -2058,9 +2052,11 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (inserted) {
it->second = std::move(region);
} else if (failed(it->second->unionBoundingBox(*region))) {
- LLVM_DEBUG(opInst->emitWarning(
+ LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
+ "memory region";
+ opInst->emitWarning(
"getMemoryFootprintBytes: unable to perform a union on a memory "
- "region"));
+ "region");
return failure();
}
return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 22608a1..7e5ce26 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -427,6 +427,21 @@ bool mlir::affine::isValidSymbol(Value value) {
return false;
}
+/// A utility function to check if a value is defined at the top level of
+/// `region` or is an argument of `region` or is defined above the region.
+static bool isTopLevelValueOrAbove(Value value, Region *region) {
+ Region *parentRegion = value.getParentRegion();
+ do {
+ if (parentRegion == region)
+ return true;
+ Operation *regionOp = region->getParentOp();
+ if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ break;
+ region = region->getParentOp()->getParentRegion();
+ } while (region);
+ return false;
+}
+
/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
@@ -445,19 +460,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
return false;
// A top-level value is a valid symbol.
- if (region && ::isTopLevelValue(value, region))
+ if (region && isTopLevelValueOrAbove(value, region))
return true;
auto *defOp = value.getDefiningOp();
- if (!defOp) {
- // A block argument that is not a top-level value is a valid symbol if it
- // dominates region's parent op.
- Operation *regionOp = region ? region->getParentOp() : nullptr;
- if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
- if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
- return isValidSymbol(value, parentOpRegion);
+ if (!defOp)
return false;
- }
// Constant operation is ok.
Attribute operandCst;
@@ -475,12 +483,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
return isDimOpValidSymbol(dimOp, region);
- // Check for values dominating `region`'s parent op.
- Operation *regionOp = region ? region->getParentOp() : nullptr;
- if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
- if (auto *parentRegion = region->getParentOp()->getParentRegion())
- return isValidSymbol(value, parentRegion);
-
return false;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6c9adff..ff0157e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
@@ -95,8 +96,8 @@ static bool canRemoveSrcNodeAfterFusion(
// Otherwise, the src loop can't be removed.
if (fusedLoopInsPoint != depNodeOp &&
!fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
- "dominate dependence\n");
+ LDBG() << "Src loop can't be removed: dst loop doesn't "
+ << "dominate dependence";
return false;
}
@@ -109,14 +110,13 @@ static bool canRemoveSrcNodeAfterFusion(
if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
std::optional<bool> isMaximal = fusionSlice.isMaximal();
if (!isMaximal) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
- "if fusion is maximal\n");
+ LDBG() << "Src loop can't be removed: can't determine "
+ << "if fusion is maximal";
return false;
}
if (!*isMaximal) {
- LLVM_DEBUG(llvm::dbgs()
- << "Src loop can't be removed: fusion is not maximal\n");
+ LDBG() << "Src loop can't be removed: fusion is not maximal";
return false;
}
}
@@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) {
// Check if this is defined to be an alias of another memref.
if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
- if (isEscapingMemref(viewOp.getViewSource(), block))
+ if (memref == viewOp.getViewDest() &&
+ isEscapingMemref(viewOp.getViewSource(), block))
return true;
// Any op besides allocating ops wouldn't guarantee alias freedom
@@ -279,19 +280,19 @@ static std::optional<double> getAdditionalComputeFraction(
AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
int64_t &fusedLoopNestComputeCost) {
- LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+ LDBG() << "Determining additional compute fraction...";
// Compute cost of sliced and unsliced src loop nest.
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+ LDBG() << "Failed to get source loop nest stats.";
return std::nullopt;
}
// Compute cost of dst loop nest.
LoopNestStats dstLoopNestStats;
if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+ LDBG() << "Failed to get destination loop nest stats.";
return std::nullopt;
}
@@ -304,14 +305,14 @@ static std::optional<double> getAdditionalComputeFraction(
const ComputationSliceState &slice = depthSliceUnions[depth - 1];
// Skip slice union if it wasn't computed for this depth.
if (slice.isEmpty()) {
- LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+ LDBG() << "Slice wasn't computed.";
return std::nullopt;
}
if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
dstLoopNestStats, slice,
&fusedLoopNestComputeCost)) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+ LDBG() << "Unable to compute fusion compute cost";
return std::nullopt;
}
@@ -348,9 +349,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
MemRefAccess bM(cast<AffineWriteOpInterface>(b));
return aM == bM;
})) {
- LLVM_DEBUG(llvm::dbgs()
- << "Private memref creation unsupported for multiple producer "
- "stores with different access functions.\n");
+ LDBG() << "Private memref creation unsupported for multiple producer "
+ << "stores with different access functions.";
return nullptr;
}
@@ -455,8 +455,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
- LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
- << '\n');
+ LDBG() << "Created private memref of type: " << newMemRefType;
return newMemRef;
}
@@ -505,15 +504,12 @@ static bool isFusionProfitable(AffineForOp srcForOp,
unsigned maxLegalFusionDepth,
unsigned *dstLoopDepth,
double computeToleranceThreshold) {
- LLVM_DEBUG({
- llvm::dbgs()
- << "Checking whether fusion is profitable between source nest:\n";
- llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
- llvm::dbgs() << dstForOp << "\n";
- });
+ LDBG() << "Checking whether fusion is profitable between source nest:";
+ LDBG() << ' ' << srcForOp << " and destination nest:";
+ LDBG() << dstForOp;
if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
+ LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
return false;
}
@@ -537,8 +533,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// TODO: Suppport multiple producer stores in profitability
// analysis.
if (producerStores.size() > 1) {
- LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
- "supported for multiple producer store case.\n");
+ LDBG() << "Limited profitability analysis. Not "
+ << "supported for multiple producer store case.";
int64_t sliceCost;
int64_t fusedLoopNestComputeCost;
// We will still fuse if fusion obeys the specified compute
@@ -547,12 +543,11 @@ static bool isFusionProfitable(AffineForOp srcForOp,
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
fusedLoopNestComputeCost);
if (!fraction || fraction > computeToleranceThreshold) {
- LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
- "compute tolerance. Not fusing.\n");
+ LDBG() << "Additional computation exceeds "
+ << "compute tolerance. Not fusing.";
return false;
}
- LLVM_DEBUG(llvm::dbgs()
- << "Considering fusion profitable at max legal depth.\n");
+ LDBG() << "Considering fusion profitable at max legal depth.";
return true;
}
@@ -574,8 +569,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute MemRefRegion for source operation\n");
+ LDBG() << "Unable to compute MemRefRegion for source operation";
return false;
}
@@ -609,8 +603,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!mayAdditionalComputeFraction) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't determine additional compute fraction.\n");
+ LDBG() << "Can't determine additional compute fraction.";
continue;
}
double additionalComputeFraction = *mayAdditionalComputeFraction;
@@ -620,9 +613,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to compute slice write region at loopDepth: " << i
- << "\n");
+ LDBG() << "Failed to compute slice write region at loopDepth: " << i;
continue;
}
@@ -630,9 +621,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
sliceWriteRegion.getRegionSize();
if (!maybeSliceWriteRegionSizeBytes.has_value() ||
*maybeSliceWriteRegionSizeBytes == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to get slice write region size at loopDepth: " << i
- << "\n");
+ LDBG() << "Failed to get slice write region size at loopDepth: " << i;
continue;
}
int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
@@ -649,9 +638,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
<< " storage reduction factor: " << storageReduction << "x\n"
<< " fused nest cost: " << fusedLoopNestComputeCost << "\n"
<< " src write region size: " << srcWriteRegionSizeBytes << "\n"
- << " slice write region size: " << sliceWriteRegionSizeBytes
- << "\n";
- llvm::dbgs() << msg.str();
+ << " slice write region size: " << sliceWriteRegionSizeBytes;
+ LDBG() << msg.str();
});
// TODO: This is a placeholder cost model.
@@ -670,28 +658,24 @@ static bool isFusionProfitable(AffineForOp srcForOp,
// A simple cost model: fuse if it reduces the memory footprint.
if (!bestDstLoopDepth) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "All fusion choices involve more than the threshold amount of "
- "redundant computation; NOT fusing.\n");
+ LDBG() << "All fusion choices involve more than the threshold amount of "
+ << "redundant computation; NOT fusing.";
return false;
}
if (!bestDstLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+ LDBG() << "no fusion depth could be evaluated.";
return false;
}
// Set dstLoopDepth based on best values from search.
*dstLoopDepth = *bestDstLoopDepth;
- LLVM_DEBUG(
- llvm::dbgs() << " LoopFusion fusion stats:"
- << "\n best loop depth: " << bestDstLoopDepth
- << "\n src loop nest compute cost: " << srcLoopNestCost
- << "\n dst loop nest compute cost: " << dstLoopNestCost
- << "\n fused loop nest compute cost: "
- << minFusedLoopNestComputeCost << "\n");
+ LDBG() << " LoopFusion fusion stats:";
+ LDBG() << " best loop depth: " << bestDstLoopDepth;
+ LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
+ LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
+ LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
auto dstMemSize = getMemoryFootprintBytes(dstForOp);
auto srcMemSize = getMemoryFootprintBytes(srcForOp);
@@ -699,8 +683,7 @@ static bool isFusionProfitable(AffineForOp srcForOp,
std::optional<double> storageReduction;
if (!dstMemSize || !srcMemSize) {
- LLVM_DEBUG(llvm::dbgs()
- << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
return false;
}
@@ -710,13 +693,13 @@ static bool isFusionProfitable(AffineForOp srcForOp,
assert(sliceMemEstimate && "expected value");
auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
- LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
- << " dst mem: " << dstMemSizeVal << "\n"
- << " fused mem: " << fusedMem << "\n"
- << " slice mem: " << sliceMemEstimate << "\n");
+ LDBG() << " src mem: " << srcMemSizeVal;
+ LDBG() << " dst mem: " << dstMemSizeVal;
+ LDBG() << " fused mem: " << fusedMem;
+ LDBG() << " slice mem: " << sliceMemEstimate;
if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ LDBG() << "Fusion is not profitable; NOT fusing.";
return false;
}
storageReduction =
@@ -734,8 +717,8 @@ static bool isFusionProfitable(AffineForOp srcForOp,
<< std::setprecision(2) << additionalComputeFraction
<< "% redundant computation and a ";
msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
- msg << "% storage reduction.\n";
- llvm::dbgs() << msg.str();
+ msg << "% storage reduction.";
+ LDBG() << msg.str();
});
return true;
@@ -895,7 +878,7 @@ public:
/// No fusion is performed when producers with a user count greater than
/// `maxSrcUserCount` for any of the memrefs involved.
void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ LDBG() << "Evaluating dst loop " << dstId;
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
return;
@@ -909,7 +892,7 @@ public:
if (dstNode->op->getNumResults() > 0)
return;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
+ LDBG() << "Evaluating dst loop " << dstId;
// Sink sequential loops in 'dstNode' (and thus raise parallel loops)
// while preserving relative order. This can increase the maximum loop
@@ -936,18 +919,14 @@ public:
auto *srcNode = mdg->getNode(srcId);
auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
- LLVM_DEBUG(llvm::dbgs()
- << "Trying to fuse producer loop nest " << srcId
- << " with consumer loop nest " << dstId << "\n");
- LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
- << computeToleranceThreshold << '\n');
- LLVM_DEBUG(llvm::dbgs()
- << "Producer loop nest:\n"
- << *srcNode->op << "\n and consumer loop nest:\n"
- << *dstNode->op << '\n');
+ LDBG() << "Trying to fuse producer loop nest " << srcId
+ << " with consumer loop nest " << dstId;
+ LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
+ LDBG() << "Producer loop nest:";
+ LDBG() << *srcNode->op << " and consumer loop nest:";
+ LDBG() << *dstNode->op;
- LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
- << " for dst loop " << dstId << "\n");
+ LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
// Skip if 'srcNode' is a loop nest returning values.
// TODO: support loop nests that return values.
@@ -1018,19 +997,16 @@ public:
&depthSliceUnions[i - 1], strategy);
if (result.value == FusionResult::Success) {
maxLegalFusionDepth = i;
- LLVM_DEBUG(llvm::dbgs()
- << "Found valid slice for depth: " << i << '\n');
+ LDBG() << "Found valid slice for depth: " << i;
}
}
if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: fusion is not legal at any depth\n");
+ LDBG() << "Can't fuse: fusion is not legal at any depth";
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
- << maxLegalFusionDepth << '\n');
+ LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
double computeToleranceThresholdToUse = computeToleranceThreshold;
@@ -1040,7 +1016,7 @@ public:
// producer-consumer memref access for example). Check this and allow
// fusion accordingly.
if (hasCyclicDependence(srcAffineForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+ LDBG() << "Source nest has a cyclic dependence.";
// Maximal fusion does not check for compute tolerance threshold; so
// perform the maximal fusion only when the redundanation computation
// is zero.
@@ -1053,18 +1029,15 @@ public:
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!fraction || fraction > 0) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't perform maximal fusion with a cyclic dependence "
- "and non-zero additional compute.\n");
+ LDBG() << "Can't perform maximal fusion with a cyclic dependence "
+ << "and non-zero additional compute.";
return;
}
} else {
// Set redundant computation tolerance to zero regardless of what
// the user specified. Without this, fusion would be invalid.
- LLVM_DEBUG(llvm::dbgs()
- << "Setting compute tolerance to zero since "
- "source has a cylic dependence.\n");
+ LDBG() << "Setting compute tolerance to zero since "
+ << "source has a cylic dependence.";
computeToleranceThresholdToUse = 0;
}
}
@@ -1107,8 +1080,7 @@ public:
if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
removeSrcNode)) {
// Create a private version of this memref.
- LLVM_DEBUG(llvm::dbgs()
- << "Creating private memref for " << memref << '\n');
+ LDBG() << "Creating private memref for " << memref;
// Create a private version of this memref.
privateMemrefs.insert(memref);
}
@@ -1118,10 +1090,9 @@ public:
fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
dstNodeChanged = true;
- LLVM_DEBUG(llvm::dbgs()
- << "Fused src loop " << srcId << " into dst loop " << dstId
- << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
+ LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
+ << " at depth " << bestDstLoopDepth << ":";
+ LDBG() << dstAffineForOp;
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
if (fusedLoopInsPoint != dstAffineForOp)
@@ -1179,8 +1150,7 @@ public:
dstLoopCollector.memrefFrees);
if (removeSrcNode) {
- LLVM_DEBUG(llvm::dbgs()
- << "Removing src loop " << srcId << " after fusion\n");
+ LDBG() << "Removing src loop " << srcId << " after fusion";
// srcNode is no longer valid after it is removed from mdg.
srcAffineForOp.erase();
mdg->removeNode(srcId);
@@ -1195,7 +1165,7 @@ public:
/// user count greater than `maxSrcUserCount` for any of the memrefs involved
/// are encountered.
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
+ LDBG() << "--- Producer/Consumer Fusion ---";
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
@@ -1207,7 +1177,7 @@ public:
// Visits each node in the graph, and for each node, attempts to fuse it with
// its sibling nodes (nodes which share a parent, but no dependence edges).
void fuseSiblingNodes() {
- LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
+ LDBG() << "--- Sibling Fusion ---";
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
@@ -1289,8 +1259,7 @@ public:
maxLegalFusionDepth = i;
}
- LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
- << maxLegalFusionDepth << '\n');
+ LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
// Skip if fusion is not feasible at any loop depths.
if (maxLegalFusionDepth == 0)
@@ -1304,7 +1273,7 @@ public:
// producer-consumer memref access for example). Check this and allow
// fusion accordingly.
if (hasCyclicDependence(sibAffineForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+ LDBG() << "Source nest has a cyclic dependence.";
// Maximal fusion does not check for compute tolerance threshold; so
// perform the maximal fusion only when the redundanation computation is
// zero.
@@ -1316,17 +1285,15 @@ public:
sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
sliceCost, fusedLoopNestComputeCost);
if (!fraction || fraction > 0) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't perform maximal fusion with a cyclic dependence "
- "and non-zero additional compute.\n");
+ LDBG() << "Can't perform maximal fusion with a cyclic dependence "
+ << "and non-zero additional compute.";
return;
}
} else {
// Set redundant computation tolerance to zero regardless of what the
// user specified. Without this, fusion would be invalid.
- LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since "
- "source has a cyclic dependence.\n");
+ LDBG() << "Setting compute tolerance to zero since "
+ << "source has a cyclic dependence.";
computeToleranceThresholdToUse = 0.0;
}
}
@@ -1356,8 +1323,7 @@ public:
// slice is used in the destination.
auto isMaximal = bestSlice.isMaximal();
if (!isMaximal.value_or(false)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Slice isn't maximal; not performing sibling fusion.\n");
+ LDBG() << "Slice isn't maximal; not performing sibling fusion.";
continue;
}
@@ -1374,10 +1340,9 @@ public:
if (insertPointInst != dstForInst)
dstForInst->moveBefore(insertPointInst);
- LLVM_DEBUG(llvm::dbgs()
- << "Fused sibling nest " << sibId << " into destination nest "
- << dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
- << dstAffineForOp << "\n");
+ LDBG() << "Fused sibling nest " << sibId << " into destination nest "
+ << dstNode->id << " at depth " << bestDstLoopDepth << ":";
+ LDBG() << dstAffineForOp;
// Update data dependence graph state post fusion.
updateStateAfterSiblingFusion(sibNode, dstNode);
@@ -1555,7 +1520,7 @@ public:
void LoopFusion::runOnBlock(Block *block) {
MemRefDependenceGraph g(*block);
if (!g.init()) {
- LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
+ LDBG() << "MDG init failed";
return;
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 41cd739..c6abb0d 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -251,20 +252,20 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
FusionStrategy fusionStrategy) {
// Return 'failure' if 'dstLoopDepth == 0'.
if (dstLoopDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
+ LDBG() << "Cannot fuse loop nests at depth 0";
return FusionResult::FailPrecondition;
}
// Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
auto *block = srcForOp->getBlock();
if (block != dstForOp->getBlock()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
+ LDBG() << "Cannot fuse loop nests in different blocks";
return FusionResult::FailPrecondition;
}
// Return 'failure' if no valid insertion point for fused loop nest in 'block'
// exists which would preserve dependences.
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
+ LDBG() << "Fusion would violate dependences in block";
return FusionResult::FailBlockDependence;
}
@@ -277,14 +278,14 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
// Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
SmallVector<Operation *, 4> opsA;
if (!gatherLoadsAndStores(forOpA, opsA)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
+ LDBG() << "Fusing loops with affine.if unsupported";
return FusionResult::FailPrecondition;
}
// Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
SmallVector<Operation *, 4> opsB;
if (!gatherLoadsAndStores(forOpB, opsB)) {
- LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
+ LDBG() << "Fusing loops with affine.if unsupported";
return FusionResult::FailPrecondition;
}
@@ -296,7 +297,7 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
// TODO: 'getMaxLoopDepth' does not support forward slice fusion.
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
+ LDBG() << "Fusion would violate loop dependences";
return FusionResult::FailFusionDependence;
}
}
@@ -339,12 +340,12 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
isSrcForOpBeforeDstForOp, srcSlice);
if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
- LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+ LDBG() << "computeSliceUnion failed";
return FusionResult::FailPrecondition;
}
if (sliceComputationResult.value ==
SliceComputationResult::IncorrectSliceFailure) {
- LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
+ LDBG() << "Incorrect slice computation";
return FusionResult::FailIncorrectSlice;
}
@@ -477,7 +478,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
auto *parentForOp = forOp->getParentOp();
if (forOp != forOpRoot) {
if (!isa<AffineForOp>(parentForOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
+ LDBG() << "Expected parent AffineForOp";
return WalkResult::interrupt();
}
// Add mapping to 'forOp' from its parent AffineForOp.
@@ -498,7 +499,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount) {
// Currently only constant trip count loop nests are supported.
- LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
+ LDBG() << "Non-constant trip count unsupported";
return WalkResult::interrupt();
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 2de057d..cd216ef 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -21,9 +21,11 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -365,12 +367,11 @@ checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) {
if (input.size() <= 1)
return success();
if (failed(getIndexSet(ops, &cst))) {
- LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n");
+ LDBG() << "Index set computation failed!";
return failure();
}
if (!cst.isHyperRectangular(0, input.size())) {
- LLVM_DEBUG(llvm::dbgs()
- << "Non-hyperrectangular nests not supported for tiling!\n");
+ LDBG() << "Non-hyperrectangular nests not supported for tiling!";
return failure();
}
return success();
@@ -385,14 +386,13 @@ static LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input,
if (llvm::any_of(input,
[](AffineForOp op) { return op.getNumResults() > 0; })) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot tile nest where a loop has yield values\n");
+ LDBG() << "Cannot tile nest where a loop has yield values";
return failure();
}
// Check if the supplied `for` ops are all successively nested.
if (!isPerfectlyNested(input)) {
- LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested");
+ LDBG() << "input loops not perfectly nested";
return failure();
}
@@ -1098,7 +1098,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// If the trip count is lower than the unroll jam factor, no unroll jam.
if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollJamFactor) {
- LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
+ LDBG() << "[failed] trip count < unroll-jam factor";
return failure();
}
@@ -1339,6 +1339,15 @@ bool mlir::affine::isValidLoopInterchangePermutation(
unsigned maxLoopDepth = loops.size();
if (maxLoopDepth == 1)
return true;
+
+ // We cannot guarantee the validity of the interchange if the loops have
+ // iter_args, since the dependence analysis does not take them into account.
+ // Conservatively return false in such cases.
+ if (llvm::any_of(loops, [](AffineForOp loop) {
+ return loop.getNumIterOperands() > 0;
+ }))
+ return false;
+
// Gather dependence components for dependences between all ops in loop nest
// rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
@@ -1766,9 +1775,7 @@ findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
// We can't hoist past the definition of the memref being copied.
Value memref = region.memref;
if (!memref.getParentRegion()->isAncestor(enclosingOp->getParentRegion())) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "memref definition will end up not dominating hoist location\n");
+ LDBG() << "memref definition will end up not dominating hoist location";
break;
}
@@ -1977,7 +1984,7 @@ static LogicalResult generateCopy(
auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
- LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+ LDBG() << "Non-identity layout map not yet supported";
return failure();
}
@@ -1989,7 +1996,7 @@ static LogicalResult generateCopy(
unsigned rank = memRefType.getRank();
if (rank == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Non-zero ranked memrefs supported\n");
+ LDBG() << "Non-zero ranked memrefs supported";
return failure();
}
@@ -2001,19 +2008,18 @@ static LogicalResult generateCopy(
std::optional<int64_t> numElements =
region.getConstantBoundingSizeAndShape(&fastBufferShape, &lbs);
if (!numElements) {
- LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
+ LDBG() << "Non-constant region size not supported";
return failure();
}
if (llvm::any_of(lbs, [](AffineMap lb) { return lb.getNumResults() > 1; })) {
// This can be supported in the future if needed.
- LLVM_DEBUG(llvm::dbgs()
- << "Max lower bound for memref region start not supported\n");
+ LDBG() << "Max lower bound for memref region start not supported";
return failure();
}
if (*numElements == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
+ LDBG() << "Nothing to copy";
return success();
}
@@ -2021,9 +2027,8 @@ static LogicalResult generateCopy(
for (unsigned i = 0; i < rank; ++i) {
region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
if (lbMaps[i].getNumResults() == 0 || ubMaps[i].getNumResults() == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Missing lower or upper bound for region along dimension: "
- << i << '\n');
+ LDBG() << "Missing lower or upper bound for region along dimension: "
+ << i;
return failure();
}
}
@@ -2122,7 +2127,7 @@ static LogicalResult generateCopy(
// TODO: use all stride levels once DmaStartOp is extended for
// multi-level strides.
if (dmaStrideInfos.size() > 1) {
- LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+ LDBG() << "Only up to one level of stride supported";
return failure();
}
@@ -2309,10 +2314,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
// surrounding the this block range.
unsigned copyDepth = getNestingDepth(&*begin);
- LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
- << "\n");
- LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
- LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
+ LDBG() << "Generating copies at depth " << copyDepth;
+ LDBG() << "from begin: "
+ << OpWithFlags(&*begin, OpPrintingFlags().skipRegions());
+ LDBG() << "to inclusive end: "
+ << OpWithFlags(&*std::prev(end), OpPrintingFlags().skipRegions());
// List of memory regions to copy for. We need a map vector to have a
// guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
@@ -2349,8 +2355,8 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
return;
if (!memref.getParentRegion()->isAncestor(block->getParent())) {
- LLVM_DEBUG(llvm::dbgs() << "memref definition is inside of the depth at "
- "which copy-in/copy-out would happen\n");
+ LDBG() << "memref definition is inside of the depth at "
+ << "which copy-in/copy-out would happen";
return;
}
@@ -2358,12 +2364,10 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
/*addMemRefDimBounds=*/false))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Error obtaining memory region: semi-affine maps?\n");
- LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
+ LDBG() << "Error obtaining memory region: semi-affine maps?";
+ LDBG() << "over-approximating to the entire memref";
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
- LLVM_DEBUG(
- opInst->emitError("non-constant memref sizes not yet supported"));
+ LDBG() << "non-constant memref sizes not yet supported";
error = true;
return;
}
@@ -2392,13 +2396,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
// Perform a union with the existing region.
if (failed(it->second->unionBoundingBox(*region))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Memory region bounding box failed; "
- "over-approximating to the entire memref\n");
+ LDBG() << "Memory region bounding box failed; "
+ << "over-approximating to the entire memref";
// If the union fails, we will overapproximate.
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
- LLVM_DEBUG(opInst->emitError(
- "non-constant memref sizes not yet supported"));
+ LDBG() << "non-constant memref sizes not yet supported";
error = true;
return true;
}
@@ -2428,8 +2430,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
});
if (error) {
- LLVM_DEBUG(begin->emitError(
- "copy generation failed for one or more memref's in this block\n"));
+ LDBG() << "copy generation failed for one or more memref's in this block";
return failure();
}
@@ -2466,8 +2467,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
processRegions(writeRegions);
if (!ret) {
- LLVM_DEBUG(begin->emitError(
- "copy generation failed for one or more memref's in this block\n"));
+ LDBG() << "copy generation failed for one or more memref's in this block";
return failure();
}
@@ -2608,7 +2608,7 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
/*boundFloorDivisor=*/nullptr,
/*ub=*/nullptr, &fullTileLbPos,
&fullTileUbPos)) {
- LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
+ LDBG() << "Can't get constant diff pair for a loop";
return nullptr;
}
@@ -2667,8 +2667,7 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
for (auto loop : inputNest) {
// TODO: straightforward to generalize to a non-unit stride.
if (loop.getStepAsInt() != 1) {
- LLVM_DEBUG(llvm::dbgs()
- << "[tile separation] non-unit stride not implemented\n");
+ LDBG() << "[tile separation] non-unit stride not implemented";
return failure();
}
SmallVector<Operation *, 1> loopOp{loop.getOperation()};
@@ -2682,8 +2681,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
/*boundFloorDivisor=*/nullptr,
/*ub=*/nullptr, &lbPos, &ubPos) ||
lbPos == ubPos) {
- LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
- "equalities not yet handled\n");
+ LDBG() << "[tile separation] Can't get constant diff / "
+ << "equalities not yet handled";
return failure();
}
@@ -2741,8 +2740,8 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
AffineIfOp ifOp = createSeparationCondition(inputNest, b);
if (!ifOp) {
fullTileLoops.front().erase();
- LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
- "separation condition\n");
+ LDBG() << "All tiles are full tiles, or failure creating "
+ << "separation condition";
return failure();
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c3..7d4d818 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 93682a9..4780dbb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -12,7 +12,7 @@ add_mlir_dialect_library(MLIRArithTransforms
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
DEPENDS
MLIRArithTransformsIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 1aa8064..35365f2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -158,13 +158,11 @@ protected:
PatternRewriter &rewriter) {
// Check iterator types for matrix multiplication.
SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
- if (!((itTypes.size() == 3 &&
- (itTypes[0] == vector::IteratorType::parallel &&
- itTypes[1] == vector::IteratorType::parallel &&
- itTypes[2] == vector::IteratorType::reduction)) ||
- (itTypes.size() == 2 &&
- (itTypes[0] == vector::IteratorType::parallel &&
- itTypes[1] == vector::IteratorType::reduction))))
+ if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction) &&
+ (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::reduction))
return rewriter.notifyMatchFailure(
op, "iterator types do not correspond to matrix multiplication");
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index 35b0bd1..6cb2a56 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -183,9 +183,9 @@ protected:
Value acc;
// Conventional names for matrix dimensions.
- int64_t M = 0;
- int64_t N = 0;
- int64_t K = 0;
+ int64_t m = 0;
+ int64_t n = 0;
+ int64_t k = 0;
// Create the matrix mulitply and accumulate operation according to
// `mmlaOp`.
@@ -286,41 +286,41 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Single-dimension vector type for the entire RHS tile.
- auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType,
+ auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType,
/*scalableDims=*/{true});
// Vector type having the same number of elements as a row in the
// accumulator/output tile and the same element type.
- auto accRowTy = VectorType::get(/*shape=*/N, resultEltType,
+ auto accRowTy = VectorType::get(/*shape=*/n, resultEltType,
/*scalableDims=*/{true});
// Vector type having twice the number of elements as a row in the
// accumulator/output tile the same element type.
- auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType,
+ auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType,
/*scalableDims=*/{true});
// Vector type having half the number of elements as a row in the
// accumulator/output tile and an integer element type with twice the bit
// width.
- auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+ auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(),
/*scalableDims=*/{true});
// Vector type having the same the number of elements as a row in the
// accumulator/output tile and an integer element type with twice the bit
// width.
- auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+ auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(),
/*scalableDims=*/{true});
Location loc = op.getLoc();
// Extract LHS sub-tiles with logical shape <2xK>.
SmallVector<Value> lhsTile;
- for (int64_t i = 0; i < M; i += 2) {
+ for (int64_t i = 0; i < m; i += 2) {
// Extract two consecutive rows of the LHS tile.
auto r0 =
vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
auto r1 =
vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
// Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
- SmallVector<int64_t> shuffleIdx(2 * K);
+ SmallVector<int64_t> shuffleIdx(2 * k);
std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
// Turn it into a scalable vector.
@@ -337,13 +337,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
// Extract the RHS sub-tiles with logical shape <Kx[2]>.
SmallVector<Value> rhsTile;
- for (int64_t j = 0; j < N; j += 2)
+ for (int64_t j = 0; j < n; j += 2)
rhsTile.push_back(vector::ScalableExtractOp::create(
- rewriter, loc, flatRhsType, rhs, j * K));
+ rewriter, loc, flatRhsType, rhs, j * k));
// Extract and pack the ACC sub-tiles.
SmallVector<Value> accTile;
- for (int64_t i = 0; i < M; i += 2) {
+ for (int64_t i = 0; i < m; i += 2) {
// Extract two consecutive rows of the accumulator tile.
auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
ArrayRef<int64_t>{i});
@@ -370,28 +370,28 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
}
// Extract ACC sub-tiles.
- for (int64_t j = 0; j < N; j += 2)
+ for (int64_t j = 0; j < n; j += 2)
accTile.push_back(vector::ScalableExtractOp::create(
rewriter, loc, flatAccType, accTileVec, j * 2));
}
// Emit sub-tile matrix multiplications.
SmallVector<Value> outTile;
- for (int64_t i = 0; i < M / 2; ++i)
- for (int64_t j = 0; j < N / 2; ++j) {
- Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i],
+ for (int64_t i = 0; i < m / 2; ++i)
+ for (int64_t j = 0; j < n / 2; ++j) {
+ Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i],
rhsTile[j]);
outTile.push_back(mmla);
}
// Unpack the OUT sub-tiles and insert into the result.
Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
- for (int64_t i = 0; i < M / 2; ++i) {
+ for (int64_t i = 0; i < m / 2; ++i) {
// Collect a number of sub-tiles in a row.
Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
- for (int64_t j = 0; j < N / 2; ++j)
+ for (int64_t j = 0; j < n / 2; ++j)
row = vector::ScalableInsertOp::create(
- rewriter, loc, outTile[i * N / 2 + j], row, j * 4);
+ rewriter, loc, outTile[i * n / 2 + j], row, j * 4);
// Unpack the row to obtain two rows of the output. If we have the out
// sub-tiles transposed we obtain two consecutive output rows by
@@ -432,9 +432,9 @@ public:
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- M = lhsType.getDimSize(0);
- N = rhsType.getDimSize(0);
- K = rhsType.getDimSize(1);
+ m = lhsType.getDimSize(0);
+ n = rhsType.getDimSize(0);
+ k = rhsType.getDimSize(1);
// Check the operands have the expected shape:
// * for LHS: fixed vector MxK
@@ -442,8 +442,8 @@ public:
// * K == 8
// * M and N even and at least 2
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
@@ -504,9 +504,9 @@ public:
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- M = lhsType.getDimSize(0);
- N = rhsType.getDimSize(0);
- K = rhsType.getDimSize(1);
+ m = lhsType.getDimSize(0);
+ n = rhsType.getDimSize(0);
+ k = rhsType.getDimSize(1);
// Check the operands have the expected shape:
// * for LHS: fixed vector MxK
@@ -514,8 +514,8 @@ public:
// * K == 4
// * M and N even and at least 2
if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
- rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 ||
- M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
+ rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
+ m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index ddc64ea..91e37dd 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -248,7 +248,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
Region *definingRegion = value.getParentRegion();
// Last users of the `value` inside all blocks where the value dies.
- llvm::SmallSet<Operation *, 4> lastUsers;
+ llvm::SmallPtrSet<Operation *, 4> lastUsers;
// Find blocks in the `definingRegion` that have users of the `value` (if
// there are multiple users in the block, which one will be selected is
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f1f12f4..56ff212 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
// which otherwise could prevent removal of unnecessary allocs.
Value canonicalSource = source;
while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
- canonicalSource.getDefiningOp()))
+ canonicalSource.getDefiningOp())) {
+ if (canonicalSource != iface.getViewDest()) {
+ break;
+ }
canonicalSource = iface.getViewSource();
+ }
std::optional<Operation *> maybeCloneDeallocOp =
memref::findDealloc(cloneOp.getOutput());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 8916526..a465c95 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -37,8 +37,12 @@ using namespace mlir::bufferization;
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
+ if (value != viewLikeOp.getViewDest()) {
+ break;
+ }
value = viewLikeOp.getViewSource();
+ }
return value;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 8f983ab..0b2e080 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
registerDependencies(viewInterface.getViewSource(),
- viewInterface->getResult(0));
+ viewInterface.getViewDest());
return WalkResult::advance();
}
@@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) {
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
+ if (value != viewLikeOp.getViewDest()) {
+ break;
+ }
value = viewLikeOp.getViewSource();
+ }
return value;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 91f6f25..68ef519 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -20,6 +20,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/DebugLog.h"
#include <optional>
namespace mlir {
@@ -328,20 +329,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
"blocks");
// Bufferize the op.
- LLVM_DEBUG(llvm::dbgs()
- << "//===-------------------------------------------===//\n"
- << "IR after bufferizing: " << nextOp->getName() << "\n");
+ LDBG(3) << "//===-------------------------------------------===//\n"
+ << "IR after bufferizing: " << nextOp->getName();
rewriter.setInsertionPoint(nextOp);
if (failed(
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
- LLVM_DEBUG(llvm::dbgs()
- << "failed to bufferize\n"
- << "//===-------------------------------------------===//\n");
+ LDBG(2) << "failed to bufferize\n"
+ << "//===-------------------------------------------===//";
return nextOp->emitError("failed to bufferize op");
}
- LLVM_DEBUG(llvm::dbgs()
- << *op
- << "\n//===-------------------------------------------===//\n");
+ LDBG(3) << *op << "\n//===-------------------------------------------===//";
}
// Return early if the top-level op is entirely gone.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index a8e8353..fb7f2bb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -56,6 +56,7 @@
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/DebugLog.h"
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
@@ -616,13 +617,10 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (getParallelRegion(def.getParentRegion(), options) !=
getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
options)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "\n- bufferizes out-of-place due to parallel region:\n");
- LLVM_DEBUG(llvm::dbgs()
- << " unConflictingWrite = operand "
- << uConflictingWrite->getOperandNumber() << " of "
- << *uConflictingWrite->getOwner() << "\n");
+ LDBG() << "\n- bufferizes out-of-place due to parallel region:\n"
+ << " unConflictingWrite = operand "
+ << uConflictingWrite->getOperandNumber() << " of "
+ << *uConflictingWrite->getOwner();
return true;
}
}
@@ -631,9 +629,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
- LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
- LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
- << " of " << *readingOp << "\n");
+ LDBG() << "\n- check conflict:\n"
+ << " uRead = operand " << uRead->getOperandNumber() << " of "
+ << *readingOp;
// Find the definition of uRead by following the SSA use-def chain.
// E.g.:
@@ -648,23 +646,22 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read value has no definitions\n");
+ LDBG() << " no conflict: read value has no definitions";
continue;
}
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
for (OpOperand *uConflictingWrite : usesWrite) {
- LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
- << uConflictingWrite->getOperandNumber() << " of "
- << *uConflictingWrite->getOwner() << "\n");
+ LDBG() << " unConflictingWrite = operand "
+ << uConflictingWrite->getOperandNumber() << " of "
+ << *uConflictingWrite->getOwner();
// Check if op dominance can be used to rule out read-after-write
// conflicts.
bool useDominance =
canUseOpDominance(uRead, uConflictingWrite, definitions, state);
- LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
+ LDBG() << "\n- useDominance = " << useDominance;
// Throughout this loop, check for multiple requirements that have to be
// met for uConflictingWrite to be an actual conflict.
@@ -680,8 +677,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// inside a loop), there may be no meaningful `happensBefore`
// relationship.
if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read happens before write\n");
+ LDBG() << " no conflict: read happens before write";
continue;
}
@@ -693,8 +689,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// Note: If the op is executed multiple times (e.g., because it is
// inside a loop), it may be conflicting with itself.
if (uConflictingWrite == uRead) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: read and write are same use\n");
+ LDBG() << " no conflict: read and write are same use";
continue;
}
@@ -705,8 +700,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// multiple times.
if (state.insideMutuallyExclusiveRegions(readingOp,
conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
- "mutually exclusive regions\n");
+ LDBG() << " no conflict: read and write are in "
+ "mutually exclusive regions";
continue;
}
@@ -721,9 +716,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
state, uRead, uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
state, uConflictingWrite, uRead->get())) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: op bufferizes to element-wise access\n");
+ LDBG() << " no conflict: op bufferizes to element-wise access";
continue;
}
}
@@ -733,15 +726,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// No conflict if the operands are non-conflicting subsets.
if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
+ LDBG() << " no conflict: non-conflicting subsets";
continue;
}
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: op interace of reading op says 'no'\n");
+ LDBG() << " no conflict: op interace of reading op says 'no'";
continue;
}
}
@@ -751,9 +743,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
options.dynCastBufferizableOp(conflictingWritingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
state)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: op interace of writing op says 'no'\n");
+ LDBG() << " no conflict: op interace of writing op says 'no'";
continue;
}
}
@@ -761,29 +751,26 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// Check all possible definitions.
for (Value definition : definitions) {
- LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
+ LDBG() << " * definition = " << definition;
// No conflict if the conflicting write happens before the definition.
if (Operation *defOp = definition.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
// conflictingWritingOp happens before defOp. No conflict.
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: write happens before definition\n");
+ LDBG() << " no conflict: write happens before definition";
continue;
}
// No conflict if conflictingWritingOp is contained in defOp.
if (defOp->isProperAncestor(conflictingWritingOp)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " no conflict: write is contained in definition\n");
+ LDBG() << " no conflict: write is contained in definition";
continue;
}
} else {
auto bbArg = cast<BlockArgument>(definition);
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
- "and write happens outside of block\n");
+ LDBG() << " no conflict: definition is bbArg "
+ "and write happens outside of block";
// conflictingWritingOp happens outside of the block. No
// conflict.
continue;
@@ -795,8 +782,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
if (aliases.getNumAliases() == 1 &&
aliases.getAliases()[0].value == definition) {
- LLVM_DEBUG(llvm::dbgs()
- << " no conflict: definition and write are same\n");
+ LDBG() << " no conflict: definition and write are same";
continue;
}
@@ -804,7 +790,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (options.printConflicts)
annotateConflict(uRead, uConflictingWrite, definition);
- LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
+ LDBG() << " => RaW CONFLICT FOUND";
return true;
}
}
@@ -958,7 +944,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
for (AliasingValue alias : state.getAliasingValues(operand))
state.applyOnAliases(alias.value, checkReadOnly);
if (foundReadOnly) {
- LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
+ LDBG() << "=> NOT WRITABLE";
return true;
}
@@ -987,10 +973,9 @@ void OneShotAnalysisState::resetCache() {
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
const DominanceInfo &domInfo) {
- LLVM_DEBUG(
- llvm::dbgs() << "//===-------------------------------------------===//\n"
- << "Analyzing operand #" << operand.getOperandNumber()
- << " of " << *operand.getOwner() << "\n");
+ LDBG() << "//===-------------------------------------------===//\n"
+ << "Analyzing operand #" << operand.getOperandNumber() << " of "
+ << *operand.getOwner();
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, state) ||
@@ -1001,8 +986,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
else
state.bufferizeInPlace(operand);
- LLVM_DEBUG(llvm::dbgs()
- << "//===-------------------------------------------===//\n");
+ LDBG() << "//===-------------------------------------------===//";
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 725fa24..b593cca 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -51,14 +51,8 @@ static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
/// Return "true" if the given op is guaranteed to have neither "Allocate" nor
/// "Free" side effects.
static bool hasNeitherAllocateNorFreeSideEffect(Operation *op) {
- if (isa<MemoryEffectOpInterface>(op))
- return !hasEffect<MemoryEffects::Allocate>(op) &&
- !hasEffect<MemoryEffects::Free>(op);
- // If the op does not implement the MemoryEffectOpInterface but has has
- // recursive memory effects, then this op in isolation (without its body) does
- // not have any side effects. All the ops inside the regions of this op will
- // be processed separately.
- return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+ return !mightHaveEffect<MemoryEffects::Allocate>(op) &&
+ !mightHaveEffect<MemoryEffects::Free>(op);
}
/// Return "true" if the given op has buffer semantics. I.e., it has buffer
@@ -517,9 +511,7 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
// MemoryEffectOpInterface. They usually do not have side effects apart
// from the callee, which will be analyzed separately. (This is similar to
// "recursive memory effects".)
- if (!isa<MemoryEffectOpInterface>(op) &&
- !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() &&
- !isa<CallOpInterface>(op))
+ if (hasUnknownEffects(op) && !isa<CallOpInterface>(op))
return op->emitError(
"ops with unknown memory side effects are not supported");
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 37b4cfc..47740d3 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms
BufferizableOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
LINK_LIBS PUBLIC
MLIRBufferizationDialect
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index 34f2dd5..3f6dd29 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -24,7 +24,7 @@ LogicalResult mlir::impl::verifyHasDefaultDLTIDataLayoutTrait(Operation *op) {
}
DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
- return op->getAttrOfType<DataLayoutSpecAttr>(
+ return op->getAttrOfType<DataLayoutSpecInterface>(
DLTIDialect::kDataLayoutAttrName);
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6a3154..00ce3b5 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -114,11 +114,8 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
- case 16: {
- if (llvm::isa<Float16Type, BFloat16Type>(type))
- return true;
- return false;
- }
+ case 16:
+ return llvm::isa<Float16Type, BFloat16Type>(type);
case 32:
case 64:
return true;
@@ -134,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
type);
}
+bool mlir::emitc::isFundamentalType(Type type) {
+ return llvm::isa<IndexType>(type) || isPointerWideType(type) ||
+ isSupportedIntegerType(type) || isSupportedFloatType(type) ||
+ isa<emitc::PointerType>(type);
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
@@ -378,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
// ExpressionOp
//===----------------------------------------------------------------------===//
+ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ if (parser.parseOperandList(operands))
+ return parser.emitError(parser.getCurrentLocation()) << "expected operands";
+ if (succeeded(parser.parseOptionalKeyword("noinline")))
+ result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name),
+ parser.getBuilder().getUnitAttr());
+ Type type;
+ if (parser.parseColonType(type))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected function type");
+ auto fnType = llvm::dyn_cast<FunctionType>(type);
+ if (!fnType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected function type");
+ if (parser.resolveOperands(operands, fnType.getInputs(),
+ parser.getCurrentLocation(), result.operands))
+ return failure();
+ if (fnType.getNumResults() != 1)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected single return type");
+ result.addTypes(fnType.getResults());
+ Region *body = result.addRegion();
+ SmallVector<OpAsmParser::Argument> argsInfo;
+ for (auto [unresolvedOperand, operandType] :
+ llvm::zip(operands, fnType.getInputs())) {
+ OpAsmParser::Argument argInfo;
+ argInfo.ssaName = unresolvedOperand;
+ argInfo.type = operandType;
+ argsInfo.push_back(argInfo);
+ }
+ if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+ return failure();
+ return success();
+}
+
+void emitc::ExpressionOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p.printOperands(getDefs());
+ p << " : ";
+ p.printFunctionalType(getOperation());
+ p.shadowRegionArgs(getRegion(), getDefs());
+ p << ' ';
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
Operation *ExpressionOp::getRootOp() {
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
Value yieldedValue = yieldOp.getResult();
@@ -1398,6 +1447,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
+
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
TypeAttr type,
Attribute initialValue) {
@@ -1455,6 +1505,15 @@ LogicalResult FieldOp::verify() {
//===----------------------------------------------------------------------===//
// GetFieldOp
//===----------------------------------------------------------------------===//
+
+LogicalResult GetFieldOp::verify() {
+ auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
+ if (!parentClassOp.getOperation())
+ return emitOpError(" must be nested within an emitc.class operation");
+
+ return success();
+}
+
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
FieldOp fieldOp =
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 3f0690c..f8469b8 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -9,7 +9,9 @@
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace emitc {
@@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
Location loc = op->getLoc();
builder.setInsertionPointAfter(op);
- auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType);
+ auto expressionOp =
+ emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
// Replace all op's uses with the new expression's result.
result.replaceAllUsesWith(expressionOp.getResult());
- // Create an op to yield op's value.
- Region &region = expressionOp.getRegion();
- Block &block = region.emplaceBlock();
+ Block &block = expressionOp.createBody();
+ IRMapping mapper;
+ for (auto [operand, arg] :
+ llvm::zip(expressionOp.getOperands(), block.getArguments()))
+ mapper.map(operand, arg);
builder.setInsertionPointToEnd(&block);
- auto yieldOp = emitc::YieldOp::create(builder, loc, result);
- // Move op into the new expression.
- op->moveBefore(yieldOp);
+ Operation *rootOp = builder.clone(*op, mapper);
+ op->erase();
+ // Create an op to yield op's value.
+ emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
return expressionOp;
}
@@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
using OpRewritePattern<ExpressionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpressionOp expressionOp,
PatternRewriter &rewriter) const override {
- bool anythingFolded = false;
- for (Operation &op : llvm::make_early_inc_range(
- expressionOp.getBody()->without_terminator())) {
- // Don't fold expressions whose result value has its address taken.
- auto applyOp = dyn_cast<emitc::ApplyOp>(op);
- if (applyOp && applyOp.getApplicableOperator() == "&")
- continue;
-
- for (Value operand : op.getOperands()) {
- auto usedExpression = operand.getDefiningOp<ExpressionOp>();
- if (!usedExpression)
- continue;
-
- // Don't fold expressions with multiple users: assume any
- // re-materialization was done separately.
- if (!usedExpression.getResult().hasOneUse())
- continue;
-
- // Don't fold expressions with side effects.
- if (usedExpression.hasSideEffects())
- continue;
-
- // Fold the used expression into this expression by cloning all
- // instructions in the used expression just before the operation using
- // its value.
- rewriter.setInsertionPoint(&op);
- IRMapping mapper;
- for (Operation &opToClone :
- usedExpression.getBody()->without_terminator()) {
- Operation *clone = rewriter.clone(opToClone, mapper);
- mapper.map(&opToClone, clone);
- }
-
- Operation *expressionRoot = usedExpression.getRootOp();
- Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
- assert(clonedExpressionRootOp &&
- "Expected cloned expression root to be in mapper");
- assert(clonedExpressionRootOp->getNumResults() == 1 &&
- "Expected cloned root to have a single result");
-
- rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
- anythingFolded = true;
- }
+ Block *expressionBody = expressionOp.getBody();
+ ExpressionOp usedExpression;
+ SetVector<Value> foldedOperands;
+
+ auto takesItsOperandsAddress = [](Operation *user) {
+ auto applyOp = dyn_cast<emitc::ApplyOp>(user);
+ return applyOp && applyOp.getApplicableOperator() == "&";
+ };
+
+ // Select as expression to fold the first operand expression that
+ // - doesn't have its result value's address taken,
+ // - has a single user: assume any re-materialization was done separately,
+ // - has no side effects,
+ // and save all other operands to be used later as operands in the folded
+ // expression.
+ for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
+ expressionBody->getArguments())) {
+ ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
+ if (usedExpression || !operandExpression ||
+ llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
+ !operandExpression.getResult().hasOneUse() ||
+ operandExpression.hasSideEffects())
+ foldedOperands.insert(operand);
+ else
+ usedExpression = operandExpression;
}
- return anythingFolded ? success() : failure();
+
+ // If no operand expression was selected, bail out.
+ if (!usedExpression)
+ return failure();
+
+ // Collect additional operands from the folded expression.
+ for (Value operand : usedExpression.getOperands())
+ foldedOperands.insert(operand);
+
+ // Create a new expression to hold the folding result.
+ rewriter.setInsertionPointAfter(expressionOp);
+ auto foldedExpression = emitc::ExpressionOp::create(
+ rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+ foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
+ Block &foldedExpressionBody = foldedExpression.createBody();
+
+ // Map each operand of the new expression to its matching block argument.
+ IRMapping mapper;
+ for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
+ foldedExpressionBody.getArguments()))
+ mapper.map(operand, arg);
+
+ // Prepare to fold the used expression and the matched expression into the
+ // newly created folded expression.
+ auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
+ bool withTerminator) {
+ Block *expressionToFoldBody = expressionToFold.getBody();
+ for (auto [operand, arg] :
+ llvm::zip(expressionToFold.getOperands(),
+ expressionToFoldBody->getArguments())) {
+ mapper.map(arg, mapper.lookup(operand));
+ }
+
+ for (Operation &opToClone : expressionToFoldBody->without_terminator())
+ rewriter.clone(opToClone, mapper);
+
+ if (withTerminator)
+ rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
+ };
+ rewriter.setInsertionPointToStart(&foldedExpressionBody);
+
+ // First, fold the used expression into the new expression and map its
+ // result to the clone of its root operation within the new expression.
+ foldExpression(usedExpression, /*withTerminator=*/false);
+ Operation *expressionRoot = usedExpression.getRootOp();
+ Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
+ assert(clonedExpressionRootOp &&
+ "Expected cloned expression root to be in mapper");
+ assert(clonedExpressionRootOp->getNumResults() == 1 &&
+ "Expected cloned root to have a single result");
+ mapper.map(usedExpression.getResult(),
+ clonedExpressionRootOp->getResults()[0]);
+
+ // Now fold the matched expression into the new expression.
+ foldExpression(expressionOp, /*withTerminator=*/true);
+
+ // Complete the rewrite.
+ rewriter.replaceOp(expressionOp, foldedExpression);
+ rewriter.eraseOp(usedExpression);
+
+ return success();
}
};
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index c55e26e..06d7e07 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -64,8 +64,8 @@ public:
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
- FieldOp fieldop = rewriter.create<emitc::FieldOp>(
- funcOp->getLoc(), fieldName, typeAttr, nullptr);
+ FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+ fieldName, typeAttr, nullptr);
if (argAttrs && idx < argAttrs->size()) {
fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5a72ef1..b87b4f4 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -756,7 +756,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
Type asyncTokenType, ValueRange asyncDependencies,
TypeRange workgroupAttributions,
TypeRange privateAttributions, Value clusterSizeX,
- Value clusterSizeY, Value clusterSizeZ) {
+ Value clusterSizeY, Value clusterSizeZ,
+ FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
OpBuilder::InsertionGuard g(builder);
// Add a WorkGroup attribution attribute. This attribute is required to
@@ -781,6 +782,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
if (dynamicSharedMemorySize)
result.addOperands(dynamicSharedMemorySize);
+ // Add optional module and function attributes.
+ if (module)
+ result.addAttribute(getModuleAttrName(result.name), module);
+ if (function)
+ result.addAttribute(getFunctionAttrName(result.name), function);
+
// Create a kernel body region with kNumConfigRegionAttributes + N memory
// attributions, where the first kNumConfigRegionAttributes arguments have
// `index` type and the rest have the same types as the data operands.
@@ -944,6 +951,21 @@ void LaunchOp::print(OpAsmPrinter &p) {
p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
<< getDynamicSharedMemorySize();
+ // Print optional module attribute.
+ StringRef moduleAttrName = getModuleAttrName();
+ if (auto module = getModule()) {
+ p << ' ' << moduleAttrName << '(';
+ p.printSymbolName(*module);
+ p << ')';
+ }
+ // Print optional function attribute.
+ StringRef functionAttrName = getFunctionAttrName();
+ if (auto function = getFunction()) {
+ p << ' ' << functionAttrName << '(';
+ p.printSymbolName(*function);
+ p << ')';
+ }
+
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
@@ -952,7 +974,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
LaunchOp::getOperandSegmentSizeAttr(),
- getNumWorkgroupAttributionsAttrName()});
+ getNumWorkgroupAttributionsAttrName(),
+ moduleAttrName, functionAttrName});
}
// Parse the size assignment blocks for blocks and threads. These have the form
@@ -990,6 +1013,9 @@ parseSizeAssignment(OpAsmParser &parser,
/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
+/// (`dynamic_shared_memory_size` ssa-use)?
+/// (`module(` symbol-ref-id `)`)?
+/// (`function(` symbol-ref-id `)`)?
/// memory-attribution
/// region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -1060,6 +1086,27 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
+ // Parse optional module attribute.
+ StringRef moduleAttrName = getModuleAttrName(result.name);
+ if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
+ FlatSymbolRefAttr moduleSymbol;
+ if (parser.parseLParen() ||
+ parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
+ result.attributes) ||
+ parser.parseRParen())
+ return failure();
+ }
+ // Parse optional function attribute.
+ StringRef functionAttrName = getFunctionAttrName(result.name);
+ if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
+ FlatSymbolRefAttr funcSymbol;
+ if (parser.parseLParen() ||
+ parser.parseAttribute(funcSymbol, Type(), functionAttrName,
+ result.attributes) ||
+ parser.parseRParen())
+ return failure();
+ }
+
// Create the region arguments, it has kNumConfigRegionAttributes arguments
// that correspond to block/thread identifiers and grid/block sizes, all
// having `index` type, a variadic number of WorkGroup Attributions and
@@ -2439,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
if (getArgs().size() != getWarpRegion().getNumArguments())
return emitOpError(
"expected same number op arguments and block arguments.");
- auto yield =
- cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = getTerminator();
if (yield.getNumOperands() != getNumResults())
return emitOpError(
"expected same number of yield operands and return values.");
@@ -2464,6 +2510,50 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
+gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
+ return cast<gpu::YieldOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupBroadcastOp
+//===----------------------------------------------------------------------===//
+
+void gpu::SubgroupBroadcastOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), argRanges.front());
+}
+
+Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
+ switch (getBroadcastType()) {
+ case BroadcastType::first_active_lane:
+ // Cannot speculate first_lane broadcast, because speculating it across
+ // control flow can change the active lanes.
+ return Speculation::NotSpeculatable;
+ case BroadcastType::any_lane:
+ LLVM_FALLTHROUGH;
+ case BroadcastType::specific_lane:
+ // Speculation should be safe as long as we inside structured control flow.
+ return Speculation::Speculatable;
+ }
+}
+
+LogicalResult gpu::SubgroupBroadcastOp::verify() {
+ switch (getBroadcastType()) {
+ case BroadcastType::first_active_lane:
+ LLVM_FALLTHROUGH;
+ case BroadcastType::any_lane:
+ if (getLane())
+ return emitOpError()
+ << "lane can only be specified for `specific_lane` broadcast";
+ return success();
+ case BroadcastType::specific_lane:
+ if (!getLane())
+ return emitOpError()
+ << "lane must be specified for `specific_lane` broadcast";
+ return success();
+ }
+}
+
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 21cb2f6..c766539 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
@@ -43,6 +44,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
+#include <optional>
#include <type_traits>
using namespace mlir;
@@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ std::optional<StringRef> chipsetName = getChipset();
+ std::optional<amdgpu::Chipset> maybeChipset;
+ if (chipsetName) {
+ FailureOr<amdgpu::Chipset> parsedChipset =
+ amdgpu::Chipset::parse(*chipsetName);
+ assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
+ maybeChipset = parsedChipset;
+ }
+
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 9bf11c7..d2c2138 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -25,6 +25,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_GPUELIMINATEBARRIERS
@@ -37,9 +38,6 @@ using namespace mlir::gpu;
#define DEBUG_TYPE "gpu-erase-barriers"
#define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
-
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.
@@ -424,27 +422,18 @@ static bool maybeCaptured(Value v) {
/// everything. This seems sufficient to achieve barrier removal in structured
/// control flow, more complex cases would require a proper dataflow analysis.
static bool mayAlias(Value first, Value second) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
- DBGS_ALIAS() << "checking aliasing between ";
- DBGS_ALIAS() << first << "\n";
- DBGS_ALIAS() << " and ";
- DBGS_ALIAS() << second << "\n";
- });
+ LDBG(DEBUG_TYPE_ALIAS, 1)
+ << "checking aliasing between " << first << " and " << second;
first = getBase(first);
second = getBase(second);
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
- DBGS_ALIAS() << "base ";
- DBGS_ALIAS() << first << "\n";
- DBGS_ALIAS() << " and ";
- DBGS_ALIAS() << second << "\n";
- });
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "base " << first << " and " << second;
// Values derived from the same base memref do alias (unless we do a more
// advanced analysis to prove non-overlapping accesses).
if (first == second) {
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n");
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "-> do alias!";
return true;
}
@@ -493,7 +482,7 @@ static bool mayAlias(Value first, Value second) {
return false;
// Otherwise, conservatively assume aliasing.
- DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n");
+ LDBG(DEBUG_TYPE_ALIAS, 1) << "-> may alias!";
return true;
}
@@ -567,20 +556,16 @@ haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
continue;
// Other kinds of effects create a conflict, e.g. read-after-write.
- LLVM_DEBUG(
- DBGS() << "found a conflict between (before): " << before.getValue()
- << " read:" << isa<MemoryEffects::Read>(before.getEffect())
- << " write:" << isa<MemoryEffects::Write>(before.getEffect())
- << " alloc:"
- << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:"
- << isa<MemoryEffects::Free>(before.getEffect()) << "\n");
- LLVM_DEBUG(
- DBGS() << "and (after): " << after.getValue()
- << " read:" << isa<MemoryEffects::Read>(after.getEffect())
- << " write:" << isa<MemoryEffects::Write>(after.getEffect())
- << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
- << " free:" << isa<MemoryEffects::Free>(after.getEffect())
- << "\n");
+ LDBG() << "found a conflict between (before): " << before.getValue()
+ << " read:" << isa<MemoryEffects::Read>(before.getEffect())
+ << " write:" << isa<MemoryEffects::Write>(before.getEffect())
+ << " alloc:" << isa<MemoryEffects::Allocate>(before.getEffect())
+ << " free:" << isa<MemoryEffects::Free>(before.getEffect());
+ LDBG() << "and (after): " << after.getValue()
+ << " read:" << isa<MemoryEffects::Read>(after.getEffect())
+ << " write:" << isa<MemoryEffects::Write>(after.getEffect())
+ << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
+ << " free:" << isa<MemoryEffects::Free>(after.getEffect());
return true;
}
}
@@ -595,8 +580,8 @@ public:
LogicalResult matchAndRewrite(BarrierOp barrier,
PatternRewriter &rewriter) const override {
- LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "
- << barrier.getLoc() << "\n");
+ LDBG() << "checking the necessity of: " << barrier << " "
+ << barrier.getLoc();
SmallVector<MemoryEffects::EffectInstance> beforeEffects;
getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true);
@@ -605,14 +590,12 @@ public:
getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true);
if (!haveConflictingEffects(beforeEffects, afterEffects)) {
- LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "
- << barrier << "\n");
+ LDBG() << "the surrounding barriers are sufficient, removing " << barrier;
rewriter.eraseOp(barrier);
return success();
}
- LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "
- << barrier.getLoc() << "\n");
+ LDBG() << "barrier is necessary: " << barrier << " " << barrier.getLoc();
return failure();
}
};
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 99f5c5b..97adad6 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -356,8 +356,8 @@ public:
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
SetVector<Value> operands;
std::string kernelFnName;
- if (op.getKernelFunc()) {
- kernelFnName = op.getKernelFunc()->getRootReference().str();
+ if (op.getFunction()) {
+ kernelFnName = op.getFunction()->str();
} else {
kernelFnName =
Twine(op->getParentOfType<SymbolOpInterface>().getName(),
@@ -403,9 +403,8 @@ private:
OpBuilder builder(context);
std::string kernelModuleName;
gpu::GPUModuleOp kernelModule;
- if (gpuLaunchOp.getKernelModule()) {
- kernelModuleName =
- gpuLaunchOp.getKernelModule()->getRootReference().str();
+ if (gpuLaunchOp.getModule()) {
+ kernelModuleName = gpuLaunchOp.getModule()->str();
kernelModule =
parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName);
} else {
@@ -432,8 +431,7 @@ private:
if (std::optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName =
- cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
+ StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 18c69f5..67cef8a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -11,16 +11,21 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
+#include <optional>
using namespace mlir;
namespace {
+
+constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
+
/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
/// and offset must be a constant integer in the range [0, 31].
struct PromoteShuffleToSwizzlePattern
@@ -56,9 +61,48 @@ struct PromoteShuffleToSwizzlePattern
return success();
}
};
+
+/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToPermlanePattern
+ : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getMode() != gpu::ShuffleMode::XOR)
+ return rewriter.notifyMatchFailure(op,
+ "only xor shuffle mode is supported");
+
+ if (!isConstantIntValue(op.getWidth(), 64))
+ return rewriter.notifyMatchFailure(op,
+ "only 64 width shuffle is supported");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ if (offsetValue != 16 && offsetValue != 32)
+ return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
+
+ Location loc = op.getLoc();
+ Value res = amdgpu::PermlaneSwapOp::create(
+ rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
+ Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
+ rewriter.replaceOp(op, {res, valid});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
- RewritePatternSet &patterns) {
- patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
+ RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
+ patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
+ /*benefit*/ 1);
+ if (maybeChipset && *maybeChipset >= kGfx950)
+ patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
+ /*benefit*/ 2);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
index e9cf493..6da76e9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "llvm/Support/Regex.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 384d1a0..88f531f 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
#include <numeric>
@@ -55,28 +56,30 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
SmallVector<size_t> &indices) const {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
+ gpu::YieldOp yield = warpOp.getTerminator();
+ SmallVector<Value> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ llvm::SmallDenseMap<Value, unsigned> indexLookup;
+ // Record the value -> first index mapping for faster lookup.
+ for (auto [i, v] : llvm::enumerate(yieldValues)) {
+ if (!indexLookup.count(v))
+ indexLookup[v] = i;
+ }
+
for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(value)) {
+ // If the value already exists in the yield, don't create a new output.
+ if (indexLookup.count(value)) {
+ indices.push_back(indexLookup[value]);
+ } else {
+ // If the value is new, add it to the yield and to the types.
+ yieldValues.push_back(value);
types.push_back(type);
indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == value) {
- indices.push_back(idx);
- break;
- }
- }
}
}
- yieldValues.insert_range(newYieldedValues);
+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
+ rewriter, warpOp, yieldValues, types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
@@ -85,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
OpOperand *WarpDistributionPattern::getWarpResult(
WarpExecuteOnLane0Op warpOp,
llvm::function_ref<bool(Operation *)> fn) const {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
for (OpOperand &yieldOperand : yield->getOpOperands()) {
Value yieldValues = yieldOperand.get();
Operation *definedOp = yieldValues.getDefiningOp();
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ff55f17..ec581ac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemorySlotInterfaces
+ MLIRPtrMemorySpaceInterfaces
MLIRSideEffectInterfaces
MLIRSupport
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 894de44..7220e10 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,20 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
@@ -28,50 +38,122 @@ using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
-static char getRegisterType(Type type) {
- if (type.isInteger(1))
- return 'b';
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == kSharedMemorySpace) {
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+ MLIRContext *ctx = type.getContext();
+ auto i16 = IntegerType::get(ctx, 16);
+ auto i32 = IntegerType::get(ctx, 32);
+ auto f32 = Float32Type::get(ctx);
+
+ auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
+ if (type.isInteger(1))
+ return 'b';
+ if (type.isInteger(16))
+ return 'h';
+ if (type.isInteger(32))
return 'r';
+ if (type.isInteger(64))
+ return 'l';
+ if (type.isF32())
+ return 'f';
+ if (type.isF64())
+ return 'd';
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
+ // Shared address spaces is addressed with 32-bit pointers.
+ if (ptr.getAddressSpace() == kSharedMemorySpace) {
+ return 'r';
+ }
+ return 'l';
}
- return 'l';
+ // register type for struct is not supported.
+ mlir::emitError(
+ loc, "The register type could not be deduced from MLIR type. The ")
+ << type
+ << " is not supported. Supported types are:"
+ "i1, i16, i32, i64, f32, f64,"
+ "pointers.\nPlease use llvm.bitcast if you have different type. "
+ "\nSee the constraints from here: "
+ "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+ "index.html#constraints";
+ return failure();
+ };
+
+ // Packed registers
+ if (auto v = dyn_cast<VectorType>(type)) {
+ assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
+
+ int64_t lanes = v.getNumElements();
+ Type elem = v.getElementType();
+
+ // Case 1. Single vector
+ if (lanes <= 1)
+ return getRegisterTypeForScalar(elem);
+
+ // Case 2. Packed registers
+ Type widened = elem;
+ switch (lanes) {
+
+ case 2:
+ if (elem.isF16() || elem.isBF16()) // vector<2xf16>
+ widened = f32;
+ else if (elem.isFloat(8)) // vector<2xf8>
+ widened = i16;
+ break;
+ case 4:
+ if (elem.isInteger(8)) // vector<i8x4>
+ widened = i32;
+ else if (elem.isFloat(8)) // vector<f8x4>
+ widened = f32;
+ else if (elem.isFloat(4)) // vector<f4x4>
+ widened = i16;
+ break;
+ // Other packing is not supported
+ default:
+ break;
+ }
+ return getRegisterTypeForScalar(widened);
}
- // register type for struct is not supported.
- llvm_unreachable("The register type could not deduced from MLIR type");
- return '?';
+
+ return getRegisterTypeForScalar(type);
}
-static char getRegisterType(Value v) {
+static FailureOr<char> getRegisterType(Value v, Location loc) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
- return getRegisterType(v.getType());
+ return getRegisterType(v.getType(), loc);
}
-void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
- LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+/// Extract every element of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+ Location loc, Value structVal) {
+ auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
+ assert(structTy && "expected LLVM struct");
+
+ SmallVector<Value> elems;
+ for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
+ elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i));
+
+ return elems;
+}
+
+LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+ LDBG() << v << "\t Modifier : " << itype << "\n";
+ registerModifiers.push_back(itype);
+
+ Location loc = interfaceOp->getLoc();
auto getModifier = [&]() -> const char * {
- if (itype == PTXRegisterMod::ReadWrite) {
- assert(false && "Read-Write modifier is not supported. Try setting the "
- "same value as Write and Read separately.");
- return "+";
- }
- if (itype == PTXRegisterMod::Write) {
+ switch (itype) {
+ case PTXRegisterMod::Read:
+ return "";
+ case PTXRegisterMod::Write:
return "=";
+ case PTXRegisterMod::ReadWrite:
+ // "Read-Write modifier is not actually supported
+ // Interface will change it to "=" later and add integer mapping
+ return "+";
}
- return "";
+ llvm_unreachable("Unknown PTX register modifier");
};
+
auto addValue = [&](Value v) {
if (itype == PTXRegisterMod::Read) {
ptxOperands.push_back(v);
@@ -90,35 +172,273 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
- Value extractValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), v, idx);
+ Value extractValue =
+ LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
- ss << getModifier() << getRegisterType(t) << ",";
+ FailureOr<char> regType = getRegisterType(t, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc,
+ "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
}
- return;
+ return success();
}
// Handle Scalars
addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
+ FailureOr<char> regType = getRegisterType(v, loc);
+ if (failed(regType))
+ return rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
+ return success();
+}
+
+/// Check if the operation needs to pack and unpack results.
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+ if (needsManualRegisterMapping)
+ return false;
+ const unsigned writeOnlyVals = interfaceOp->getNumResults();
+ const unsigned readWriteVals =
+ llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+ return (writeOnlyVals + readWriteVals) > 1;
+}
+
+/// Pack the result types of the interface operation.
+/// If the operation has multiple results, it packs them into a struct
+/// type. Otherwise, it returns the original result types.
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+ SmallVectorImpl<Value> &ptxOperands) {
+ MLIRContext *ctx = interfaceOp->getContext();
+ TypeRange resultRange = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ // Single value path:
+ if (interfaceOp->getResults().size() == 1)
+ return SmallVector<Type>{resultRange.front()};
+
+ // No declared results: if there is an RW, forward its type.
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ return SmallVector<Type>{v.getType()};
+ }
+
+ SmallVector<Type> packed;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ packed.push_back(v.getType());
+ for (Type t : resultRange)
+ packed.push_back(t);
+
+ if (packed.empty())
+ return {};
+
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+ return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+/// - Turn every "+X" into "=X"
+/// - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+ SmallVector<llvm::StringRef> toks;
+ SmallVector<std::string> out;
+ SmallVector<unsigned> plusIdx;
+
+ csv.split(toks, ',');
+ out.reserve(toks.size() + 8);
+
+ for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+ StringRef t = toks[i].trim();
+ if (t.consume_front("+")) {
+ plusIdx.push_back(i);
+ out.push_back(("=" + t).str());
+ } else {
+ out.push_back(t.str());
+ }
+ }
+
+ // Append indices of original "+X" tokens.
+ for (unsigned idx : plusIdx)
+ out.push_back(std::to_string(idx));
+
+ // Join back to CSV.
+ std::string result;
+ result.reserve(csv.size() + plusIdx.size() * 2);
+ llvm::raw_string_ostream os(result);
+ for (size_t i = 0; i < out.size(); ++i) {
+ if (i)
+ os << ',';
+ os << out[i];
+ }
+ return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
+constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
+constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
+
+/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
+static llvm::Regex getPredicateMappingRegex() {
+ llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
+ kReadWritePrefix, kWriteOnlyPrefix,
+ kReadOnlyPrefix)
+ .str());
+ return rx;
+}
+
+void mlir::NVVM::countPlaceholderNumbers(
+ StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
+ llvm::SmallDenseSet<unsigned int> &seenW,
+ llvm::SmallDenseSet<unsigned int> &seenR,
+ llvm::SmallVectorImpl<unsigned int> &rwNums,
+ llvm::SmallVectorImpl<unsigned int> &wNums,
+ llvm::SmallVectorImpl<unsigned int> &rNums) {
+
+ llvm::Regex rx = getPredicateMappingRegex();
+ StringRef rest = ptxCode;
+
+ SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+ while (!rest.empty() && rx.match(rest, &m)) {
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+ // Insert it into the vector only the first time we see this number
+ if (m[1].equals_insensitive(kReadWritePrefix)) {
+ if (seenRW.insert(num).second)
+ rwNums.push_back(num);
+ } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
+ if (seenW.insert(num).second)
+ wNums.push_back(num);
+ } else {
+ if (seenR.insert(num).second)
+ rNums.push_back(num);
+ }
+
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+}
+
+/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
+/// compact `$K` indices:
+/// - All `rw*` first (sorted by N),
+/// - Then `w*`,
+/// - Then `r*`.
+/// If there a predicate, it comes always in the end.
+/// Each number is assigned once; duplicates are ignored.
+///
+/// Example Input:
+/// "{
+/// reg .pred p;
+/// setp.ge.s32 p, {$r0}, {$r1};"
+/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
+/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
+/// selp.s32 {$w0}, {$r0}, {$r1}, p;
+/// selp.s32 {$w1}, {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+/// reg .pred p;
+/// setp.ge.s32 p, $4, $5;"
+/// selp.s32 $0, $4, $5, p;
+/// selp.s32 $1, $4, $5, p;
+/// selp.s32 $2, $4, $5, p;
+/// selp.s32 $3, $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
+ llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+ llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+ // Step 1. Count Register Placeholder numbers
+ countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
+
+ // Step 2. Sort the Register Placeholder numbers
+ llvm::sort(rwNums);
+ llvm::sort(wNums);
+ llvm::sort(rNums);
+
+ // Step 3. Create mapping from original to new IDs
+ llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+ unsigned nextId = 0;
+ for (unsigned n : rwNums)
+ rwMap[n] = nextId++;
+ for (unsigned n : wNums)
+ wMap[n] = nextId++;
+ for (unsigned n : rNums)
+ rMap[n] = nextId++;
+
+ // Step 4. Rewrite the PTX code with new IDs
+ std::string out;
+ out.reserve(ptxCode.size());
+ size_t prev = 0;
+ StringRef rest = ptxCode;
+ SmallVector<StringRef, 3> matches;
+ llvm::Regex rx = getPredicateMappingRegex();
+ while (!rest.empty() && rx.match(rest, &matches)) {
+ // Compute absolute match bounds in the original buffer.
+ size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
+ size_t absEnd = absStart + matches[0].size();
+
+ // Emit text before the match.
+ out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
+
+ // Emit compact $K
+ unsigned num = 0;
+ (void)matches[2].getAsInteger(10, num);
+ unsigned id = 0;
+ if (matches[1].equals_insensitive(kReadWritePrefix))
+ id = rwMap.lookup(num);
+ else if (matches[1].equals_insensitive(kWriteOnlyPrefix))
+ id = wMap.lookup(num);
+ else
+ id = rMap.lookup(num);
+
+ out.push_back('$');
+ out += std::to_string(id);
+
+ prev = absEnd;
+
+ const size_t advance =
+ (size_t)(matches[0].data() - rest.data()) + matches[0].size();
+ rest = rest.drop_front(advance);
+ }
+
+ // Step 5. Tail.
+ out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
+ return out;
}
LLVM::InlineAsmOp PtxBuilder::build() {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- auto resultTypes = interfaceOp->getResultTypes();
+ SmallVector<Type> resultTypes = packResultTypes(
+ interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
registerConstraints[registerConstraints.size() - 1] == ',')
registerConstraints.pop_back();
+ registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
std::string ptxInstruction = interfaceOp.getPtx();
+ if (!needsManualRegisterMapping)
+ ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
@@ -136,7 +456,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
rewriter, interfaceOp->getLoc(),
/*result types=*/resultTypes,
/*operands=*/ptxOperands,
- /*asm_string=*/llvm::StringRef(ptxInstruction),
+ /*asm_string=*/ptxInstruction,
/*constraints=*/registerConstraints.data(),
/*has_side_effects=*/interfaceOp.hasSideEffect(),
/*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -146,10 +466,89 @@ LLVM::InlineAsmOp PtxBuilder::build() {
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
- LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
- rewriter.replaceOp(interfaceOp, inlineAsmOp);
- } else {
+ LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
+
+ // Case 0: no result at all → just erase wrapper op.
+ if (!hasResult) {
rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ if (needsManualRegisterMapping) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ return;
+ }
+
+ // Case 1: Simple path, return single scalar
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ if (inlineAsmOp->getNumResults() > 0) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ } else {
+ // RW-only case with no declared results: forward the RW value.
+ SmallVector<Value> results;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite) {
+ results.push_back(v);
+ break;
+ }
+ rewriter.replaceOp(interfaceOp, results);
+ }
+ return;
+ }
+
+ const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+
+ // All multi-value paths produce a single struct result we need to unpack.
+ assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+ "expected struct return for multi-result inline asm");
+ Value structVal = inlineAsmOp.getResult(0);
+ SmallVector<Value> unpacked =
+ extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+ // Case 2: only declared results (no RW): replace the op with all unpacked.
+ if (!hasRW && interfaceOp->getResults().size() > 0) {
+ rewriter.replaceOp(interfaceOp, unpacked);
+ return;
+ }
+
+ // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+ if (hasRW && interfaceOp->getResults().size() == 0) {
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ // Case 4: mixed (RW + declared results).
+ {
+ // First rewrite RW operands in place.
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ // The remaining unpacked values correspond to the declared results.
+ SmallVector<Value> tail;
+ tail.reserve(unpacked.size() - idx);
+ for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+ tail.push_back(unpacked[i]);
+
+ rewriter.replaceOp(interfaceOp, tail);
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 1e02bfe..e268e8f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -51,6 +53,87 @@ void LLVMDialect::registerAttributes() {
}
//===----------------------------------------------------------------------===//
+// AddressSpaceAttr
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is an LLVM type that can be loaded or stored.
+static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!isLoadableType(type)) {
+ if (emitError)
+ emitError() << "type must be LLVM type with size, but got " << type;
+ return false;
+ }
+ if (ordering == ptr::AtomicOrdering::not_atomic)
+ return true;
+
+ // To check atomic validity we need a datalayout.
+ if (!dataLayout) {
+ if (emitError)
+ emitError() << "expected a valid data layout";
+ return false;
+ }
+ if (!isTypeCompatibleWithAtomicOp(type, *dataLayout)) {
+ if (emitError)
+ emitError() << "unsupported type " << type << " for atomic access";
+ return false;
+ }
+ return true;
+}
+
+bool AddressSpaceAttr::isValidLoad(
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+}
+
+bool AddressSpaceAttr::isValidStore(
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
+}
+
+bool AddressSpaceAttr::isValidAtomicOp(
+ ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once `ptr.atomic_rmw` is implemented.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidAtomicXchg(
+ Type type, ptr::AtomicOrdering successOrdering,
+ ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once `ptr.atomic_cmpxchg` is implemented.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidAddrSpaceCast(
+ Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once the `ptr.addrspace_cast` op is added to the
+ // dialect.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+bool AddressSpaceAttr::isValidPtrIntCast(
+ Type intLikeTy, Type ptrLikeTy,
+ function_ref<InFlightDiagnostic()> emitError) const {
+ // TODO: update this method once the int-cast ops are added to the `ptr`
+ // dialect.
+ assert(false && "unimplemented, see TODO in the source.");
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
// AliasScopeAttr
//===----------------------------------------------------------------------===//
@@ -374,6 +457,43 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
getAttributeName());
}
+FailureOr<Attribute> TargetFeaturesAttr::query(DataLayoutEntryKey key) {
+ auto stringKey = dyn_cast<StringAttr>(key);
+ if (!stringKey)
+ return failure();
+
+ if (contains(stringKey))
+ return UnitAttr::get(getContext());
+
+ if (contains((std::string("+") + stringKey.strref()).str()))
+ return BoolAttr::get(getContext(), true);
+
+ if (contains((std::string("-") + stringKey.strref()).str()))
+ return BoolAttr::get(getContext(), false);
+
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// TargetAttr
+//===----------------------------------------------------------------------===//
+
+FailureOr<::mlir::Attribute> TargetAttr::query(DataLayoutEntryKey key) {
+ if (auto stringAttrKey = dyn_cast<StringAttr>(key)) {
+ if (stringAttrKey.getValue() == "triple")
+ return getTriple();
+ if (stringAttrKey.getValue() == "chip")
+ return getChip();
+ if (stringAttrKey.getValue() == "features" && getFeatures())
+ return getFeatures();
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleFlagAttr
+//===----------------------------------------------------------------------===//
+
LogicalResult
ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LLVM::ModFlagBehavior flagBehavior, StringAttr key,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 422039f..ef27070 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
return success();
}
+static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
+ bool isExpandLoad,
+ uint64_t alignment = 1) {
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The pointer alignment defaults to 1.
+ if (alignment == 1) {
+ return nullptr;
+ }
+
+ auto emptyDictAttr = builder.getDictionaryAttr({});
+ auto alignmentAttr = builder.getI64IntegerAttr(alignment);
+ auto namedAttr =
+ builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
+ SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
+ auto alignDictAttr = builder.getDictionaryAttr(attrs);
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The align parameter attribute can be provided for [expandload]'s first
+ // argument. The align parameter attribute can be provided for
+ // [compressstore]'s second argument.
+ int pos = isExpandLoad ? 0 : 1;
+ return pos == 0 ? builder.getArrayAttr(
+ {alignDictAttr, emptyDictAttr, emptyDictAttr})
+ : builder.getArrayAttr(
+ {emptyDictAttr, alignDictAttr, emptyDictAttr});
+}
+
//===----------------------------------------------------------------------===//
// Operand bundle helpers.
//===----------------------------------------------------------------------===//
@@ -821,8 +853,8 @@ void LoadOp::getEffects(
/// Returns true if the given type is supported by atomic operations. All
/// integer, float, and pointer types with a power-of-two bitsize and a minimal
/// size of 8 bits are supported.
-static bool isTypeCompatibleWithAtomicOp(Type type,
- const DataLayout &dataLayout) {
+bool LLVM::isTypeCompatibleWithAtomicOp(Type type,
+ const DataLayout &dataLayout) {
if (!isa<IntegerType, LLVMPointerType>(type))
if (!isCompatibleFloatingPointType(type))
return false;
@@ -836,8 +868,9 @@ static bool isTypeCompatibleWithAtomicOp(Type type,
/// Verifies the attributes and the type of atomic memory access operations.
template <typename OpTy>
-LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
- ArrayRef<AtomicOrdering> unsupportedOrderings) {
+static LogicalResult
+verifyAtomicMemOp(OpTy memOp, Type valueType,
+ ArrayRef<AtomicOrdering> unsupportedOrderings) {
if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
DataLayout dataLayout = DataLayout::closest(memOp);
if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
@@ -1087,7 +1120,7 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
/// Verify that the parameter and return types of the variadic callee type match
/// the `callOp` argument and result types.
template <typename OpTy>
-LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
+static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
if (!varCalleeType)
return success();
@@ -2500,7 +2533,7 @@ LogicalResult GlobalOp::verifyRegions() {
// LLVM::GlobalCtorsOp
//===----------------------------------------------------------------------===//
-LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
+static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
if (data.empty())
return success();
@@ -4117,6 +4150,32 @@ LogicalResult LLVM::masked_scatter::verify() {
}
//===----------------------------------------------------------------------===//
+// masked_expandload (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
+ mlir::TypeRange resTys, Value ptr,
+ Value mask, Value passthru,
+ uint64_t align) {
+ ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
+ build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// masked_compressstore (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_compressstore::build(OpBuilder &builder,
+ OperationState &state, Value value,
+ Value ptr, Value mask, uint64_t align) {
+ ArrayAttr argAttrs =
+ getLLVMAlignParamForCompressExpand(builder, false, align);
+ build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
// InlineAsmOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index e7d5dad..ef38027 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "sroa"
@@ -734,9 +735,8 @@ static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
return false;
})
.Default([&](Type type) {
- LLVM_DEBUG(llvm::dbgs()
- << "[sroa] Unsupported type for offset computations"
- << type << "\n");
+ LDBG() << "[sroa] Unsupported type for offset computations"
+ << type;
return true;
});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 78b4411..297640c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -24,7 +24,9 @@ using namespace mlir::LLVM;
/// prints it as usual.
static void dispatchPrint(AsmPrinter &printer, Type type) {
if (isCompatibleType(type) &&
- !llvm::isa<IntegerType, FloatType, VectorType>(type))
+ !(llvm::isa<IntegerType, FloatType, VectorType>(type) ||
+ (llvm::isa<PtrLikeTypeInterface>(type) &&
+ !llvm::isa<LLVMPointerType>(type))))
return mlir::LLVM::detail::printType(type, printer);
printer.printType(type);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index fee2d3e..2dd0132 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -13,6 +13,7 @@
#include "TypeDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -701,6 +702,17 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const {
// Utility functions.
//===----------------------------------------------------------------------===//
+/// Check whether type is a compatible ptr type. These are pointer-like types
+/// with no element type, no metadata, and using the LLVM AddressSpaceAttr
+/// memory space.
+static bool isCompatiblePtrType(Type type) {
+ auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type);
+ if (!ptrTy)
+ return false;
+ return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr &&
+ isa<AddressSpaceAttr>(ptrTy.getMemorySpace());
+}
+
bool mlir::LLVM::isCompatibleOuterType(Type type) {
// clang-format off
if (llvm::isa<
@@ -734,7 +746,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
if (auto vecType = llvm::dyn_cast<VectorType>(type))
return vecType.getRank() == 1;
- return false;
+ return isCompatiblePtrType(type);
}
static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
@@ -784,6 +796,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMX86AMXType
>([](Type) { return true; })
// clang-format on
+ .Case<PtrLikeTypeInterface>(
+ [](Type type) { return isCompatiblePtrType(type); })
.Default([](Type) { return false; });
if (!result)
@@ -805,6 +819,18 @@ bool mlir::LLVM::isCompatibleType(Type type) {
return LLVMDialect::isCompatibleType(type);
}
+bool mlir::LLVM::isLoadableType(Type type) {
+ return /*LLVM_PrimitiveType*/ (
+ LLVM::isCompatibleOuterType(type) &&
+ !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
+ /*LLVM_OpaqueStruct*/
+ !(isa<LLVM::LLVMStructType>(type) &&
+ cast<LLVM::LLVMStructType>(type).isOpaque()) &&
+ /*LLVM_AnyTargetExt*/
+ !(isa<LLVM::LLVMTargetExtType>(type) &&
+ !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
+}
+
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type, LLVMPPCFP128Type>(type);
@@ -818,7 +844,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
return intType.isSignless();
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type, LLVMPointerType>(elementType);
+ Float80Type, Float128Type, LLVMPointerType>(elementType) ||
+ isCompatiblePtrType(elementType);
}
return false;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e0977f5..77ec1eb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
@@ -50,7 +51,6 @@ using namespace NVVM;
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
@@ -82,8 +82,27 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
- if (getCoordinates().size() > 5)
- return emitError("Maximum 5 coordinates and dimension is supported.");
+ TMAStoreMode mode = getMode();
+ // We lower through inline-ptx when getPredicate() is true.
+ // a) Only TILE mode is supported
+ // b) Cache-hint is not supported
+ if (getPredicate()) {
+ if (mode != TMAStoreMode::TILE)
+ return emitError("Inline-ptx lowering supported only for Tile mode.");
+ if (getL2CacheHint())
+ return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
+ }
+
+ size_t dims = getCoordinates().size();
+ switch (mode) {
+ case TMAStoreMode::TILE:
+ return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+ case TMAStoreMode::IM2COL:
+ return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+ case TMAStoreMode::TILE_SCATTER4:
+ if (dims != 5)
+ return emitError("Scatter4 mode expects 5 coordinates");
+ }
return success();
}
@@ -98,17 +117,59 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+// This verify params can be shared across TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+ TMALoadMode mode, Location loc) {
+ if (tensorDims < 1 || tensorDims > 5)
+ return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+ auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+ size_t expectedIm2colOff) -> LogicalResult {
+ if (isIm2col && (tensorDims < 3))
+ return emitError(loc)
+ << "to use " << stringifyEnum(mode)
+ << " mode, the tensor has to be at least 3-dimensional";
+
+ if (numIm2colOff != expectedIm2colOff)
+ return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+ << " (provided " << numIm2colOff << ")";
+
+ return success();
+ };
+
+ switch (mode) {
+ case TMALoadMode::TILE:
+ return checkTMALoadParams(mode, false, 0);
+ case TMALoadMode::IM2COL:
+ return checkTMALoadParams(mode, true, tensorDims - 2);
+ case TMALoadMode::IM2COL_W:
+ case TMALoadMode::IM2COL_W_128:
+ return checkTMALoadParams(mode, true, 2);
+ case TMALoadMode::TILE_GATHER4:
+ return (tensorDims == 5)
+ ? checkTMALoadParams(mode, false, 0)
+ : emitError(loc, "Gather4 mode expects 5 coordinates");
+ }
+ return success();
+}
+
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
- size_t numIm2ColOffsets = getIm2colOffsets().size();
- bool isIm2Col = numIm2ColOffsets > 0;
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
- numIm2ColOffsets, getLoc());
+ return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+ getMode(), getLoc());
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
- bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
- getLoc());
+ TMAStoreMode mode = getMode();
+ size_t dims = getCoordinates().size();
+ switch (mode) {
+ case TMAStoreMode::TILE:
+ return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
+ case TMAStoreMode::IM2COL:
+ return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
+ case TMAStoreMode::TILE_SCATTER4:
+ return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
+ }
+ return success();
}
LogicalResult ConvertFloatToTF32Op::verify() {
@@ -189,6 +250,26 @@ LogicalResult BulkStoreOp::verify() {
return success();
}
+LogicalResult PMEventOp::verify() {
+ auto eventId = getEventId();
+ auto maskedEventId = getMaskedEventId();
+ if (!maskedEventId && !eventId) {
+ return emitOpError() << "either `id` or `mask` must be set";
+ }
+
+ if (maskedEventId && eventId) {
+ return emitOpError() << "`id` and `mask` cannot be set at the same time";
+ }
+
+ if (eventId) {
+ if (eventId < 0 || eventId > 15) {
+ return emitOpError() << "`id` must be between 0 and 15";
+ }
+ }
+
+ return llvm::success();
+}
+
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
@@ -791,24 +872,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
}
LogicalResult NVVM::LdMatrixOp::verify() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
- if (addressSpace != NVVM::kSharedMemorySpace)
- return emitOpError("expected source pointer in memory space 3");
-
- if (getNum() != 1 && getNum() != 2 && getNum() != 4)
- return emitOpError("expected num attribute to be 1, 2 or 4");
+ uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
+ if (m == 8 && n == 8) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
+ "matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B16) {
+ return emitOpError("expected element type to be b16 for 8x8 matrix");
+ }
+ } else if (m == 8 && n == 16) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::row) {
+ return emitOpError("expected layout to be row for 8x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 8x16 matrix");
+ }
+ } else if (m == 16 && n == 16) {
+ if (num != 1 && num != 2) {
+ return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::col) {
+ return emitOpError("expected layout to be col for 16x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8 &&
+ getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 16x16 matrix");
+ }
+ } else {
+ return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
+ }
Type i32 = IntegerType::get(getContext(), 32);
- if (getNum() == 1 && getType() != i32)
+ uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
+ if (numElements == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
- if (getNum() == 2 || getNum() == 4) {
+ if (numElements == 2 || numElements == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
- getContext(), SmallVector<Type>(getNum(), i32));
+ getContext(), SmallVector<Type>(numElements, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
- << getNum() << " elements of type i32";
+ << numElements << " elements of type i32";
}
+
return success();
}
@@ -1069,7 +1184,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
return ptx;
}
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
@@ -1100,7 +1215,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
{makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
mlir::NVVM::PTXRegisterMod::Read});
}
+ return true; // Has manual mapping
}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1216,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() {
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
+ std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
- if (getUniform()) {
- if (getCacheLevel() != CacheLevel::L1)
- return emitOpError("unsupported cache level, the only supported uniform "
- "cache level is L1");
+ if (getTensormap() && cacheLevel)
+ return emitOpError("cannot specify both tensormap and cache level");
- if (addressSpace != MemSpace::kGenericMemorySpace)
+ if (getTensormap()) {
+ if (addressSpace != MemSpace::kGenericMemorySpace &&
+ addressSpace != MemSpace::kConstantMemorySpace) {
return emitOpError(
- "prefetch to uniform cache requires a generic pointer");
- }
+ "prefetch tensormap requires a generic or constant pointer");
+ }
- if (evictPriority) {
- if (getCacheLevel() != CacheLevel::L2)
+ if (evictPriority) {
return emitOpError(
- "cache eviction priority supported only for cache level L2");
-
- if (addressSpace != MemSpace::kGlobalMemorySpace)
- return emitOpError("cache eviction priority requires a global pointer");
+ "prefetch tensormap does not support eviction priority");
+ }
- if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
- *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+ if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
return emitOpError(
- "unsupported cache eviction priority, only evict_last and "
- "evict_normal are supported");
+ "in_param_space can only be specified for a generic pointer");
+ }
+
+ } else if (cacheLevel) {
+ if (addressSpace != MemSpace::kGenericMemorySpace &&
+ addressSpace != MemSpace::kGlobalMemorySpace &&
+ addressSpace != MemSpace::kLocalMemorySpace) {
+ return emitOpError("prefetch to cache level requires a generic, global, "
+ "or local pointer");
+ }
+
+ if (getUniform()) {
+ if (*cacheLevel != CacheLevel::L1) {
+ return emitOpError(
+ "unsupported cache level, the only supported uniform "
+ "cache level is L1");
+ }
+
+ if (addressSpace != MemSpace::kGenericMemorySpace) {
+ return emitOpError(
+ "prefetch to uniform cache requires a generic pointer");
+ }
+ }
+
+ if (evictPriority) {
+ if (*cacheLevel != CacheLevel::L2)
+ return emitOpError(
+ "cache eviction priority supported only for cache level L2");
+
+ if (addressSpace != MemSpace::kGlobalMemorySpace)
+ return emitOpError("cache eviction priority requires a global pointer");
+
+ if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
+ *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+ return emitOpError(
+ "unsupported cache eviction priority, only evict_last and "
+ "evict_normal are supported");
+ }
+
+ if (getPredicate())
+ return emitOpError("predicate supported only on prefetch tensormap");
+
+ } else {
+ return emitOpError(
+ "requires specification of either cache level or tensormap");
}
return success();
@@ -1379,28 +1536,102 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
+ {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
+
+ static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
+ "TMAStoreModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable(
+ "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
+
+ return {id, std::move(args)};
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
@@ -1774,26 +2005,47 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
return {ids[type], args};
}
-llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
+static llvm::Value *getParamCastedAddr(llvm::Value *addr,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateAddrSpaceCast(
+ addr,
+ llvm::PointerType::get(builder.getContext(),
+ llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
+}
+
+NVVM::IDArgPair
+PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;
- NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
+ std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
std::optional<NVVM::CacheEvictionPriority> evictPriority =
op.getEvictPriority();
unsigned addressSpace =
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
.getAddressSpace();
- if (op.getUniform() && cacheLevel == CacheLevel::L1)
- return llvm::Intrinsic::nvvm_prefetchu_L1;
+ llvm::SmallVector<llvm::Value *> args;
+ llvm::Value *addr = mt.lookupValue(op.getAddr());
+ args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
+ : addr);
+
+ if (op.getTensormap())
+ return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
+
+ assert(cacheLevel && "expected cache level for non-tensormap prefetch");
+
+ if (op.getUniform() && *cacheLevel == CacheLevel::L1)
+ return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
- if (evictPriority && cacheLevel == CacheLevel::L2) {
+ if (evictPriority && *cacheLevel == CacheLevel::L2) {
switch (*evictPriority) {
case NVVM::CacheEvictionPriority::EvictLast:
- return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
+ return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
case NVVM::CacheEvictionPriority::EvictNormal:
- return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
+ return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
default:
llvm_unreachable("Invalid cache eviction priority");
}
@@ -1801,21 +2053,41 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
switch (addressSpace) {
case MemSpace::kGenericMemorySpace:
- return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
- : llvm::Intrinsic::nvvm_prefetch_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
+ : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
case MemSpace::kGlobalMemorySpace:
- return cacheLevel == CacheLevel::L1
- ? llvm::Intrinsic::nvvm_prefetch_global_L1
- : llvm::Intrinsic::nvvm_prefetch_global_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
+ : NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
case MemSpace::kLocalMemorySpace:
- return cacheLevel == CacheLevel::L1
- ? llvm::Intrinsic::nvvm_prefetch_local_L1
- : llvm::Intrinsic::nvvm_prefetch_local_L2;
+ return *cacheLevel == CacheLevel::L1
+ ? NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
+ : NVVM::IDArgPair(
+ {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
default:
llvm_unreachable("Invalid pointer address space");
}
}
+bool NVVM::InlinePtxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ for (auto arg : getReadWriteArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
+ for (auto arg : getResults())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto arg : getReadOnlyArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
+ if (getPredicate())
+ asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
+ return false; // No manual mapping needed
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -1854,19 +2126,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
attrName == NVVMDialect::getReqntidAttrName() ||
attrName == NVVMDialect::getClusterDimAttrName()) {
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
- if (!values || values.empty() || values.size() > 3)
+ if (!values || values.empty() || values.size() > 3) {
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
+ }
}
// If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
// attribute
if (attrName == NVVMDialect::getMinctasmAttrName() ||
attrName == NVVMDialect::getMaxnregAttrName() ||
attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
- if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
+ if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
return op->emitError()
<< "'" << attrName << "' attribute must be integer constant";
+ }
+ }
+ // blocksareclusters must be used along with reqntid and cluster_dim
+ if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
+ if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
+ !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
+ return op->emitError()
+ << "'" << attrName << "' attribute must be used along with "
+ << "'" << NVVMDialect::getReqntidAttrName() << "' and "
+ << "'" << NVVMDialect::getClusterDimAttrName() << "'";
+ }
}
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
index 8317b67..23b4130 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
using namespace LLVM;
@@ -63,9 +63,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr,
}
if (maxNumRewrites && numRewrites >= *maxNumRewrites) {
- LLVM_DEBUG(llvm::dbgs()
- << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
- << maxNumRewrites << ")\n");
+ LDBG() << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
+ << maxNumRewrites << ")";
// Skip rewriting the rest.
result.append(inputs.begin(), inputs.end());
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index b951df8..4ea2ac9 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -129,7 +129,6 @@ handleInlinedAllocas(Operation *call,
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPoint(allocaOp);
LLVM::LifetimeStartOp::create(builder, allocaOp.getLoc(),
- arraySize.getValue().getLimitedValue(),
allocaOp.getResult());
}
allocaOp->moveAfter(newConstant);
@@ -147,7 +146,6 @@ handleInlinedAllocas(Operation *call,
for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
if (shouldInsertLifetime)
LLVM::LifetimeEndOp::create(builder, allocaOp.getLoc(),
- arraySize.getValue().getLimitedValue(),
allocaOp.getResult());
}
}
@@ -237,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) {
WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
// Attempt to advance to the source of the underlying view-like operation.
// Examples of view-like operations include GEPOp and AddrSpaceCastOp.
- if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>())
- return WalkContinuation::advanceTo(viewOp.getViewSource());
+ if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) {
+ if (val == viewOp.getViewDest())
+ return WalkContinuation::advanceTo(viewOp.getViewSource());
+ }
// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 34c63d3..578931e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
ArrayRef<AffineMap> indexingMaps) {
// Initialize indexingMaps attribute, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
- indexingMapsAttrVal = llvm::map_to_vector(
- MatmulOp::getDefaultIndexingMaps(b.getContext()),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
@@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
+ // 2) The payload op must have the same number of operands as the number of
+ // block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
- return nullptr;
+ return false;
+
+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+ // must be the first operand of the payload op, otherwise, the operands
+ // must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
- return nullptr;
+ return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+
+ // 4) The `yield` operand must be the result of the payload op.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1621,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
// MatMulOp
//===----------------------------------------------------------------------===//
+static FailureOr<SmallVector<SmallVector<int64_t>>>
+getAffineResultPositions(ArrayAttr maps) {
+ SmallVector<SmallVector<int64_t>> positions;
+ for (auto map : maps) {
+ AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
+ if (!attr)
+ return failure();
+ SmallVector<int64_t> pos;
+ for (auto result : attr.getAffineMap().getResults()) {
+ auto dim = dyn_cast<AffineDimExpr>(result);
+ if (!dim)
+ return failure();
+ pos.push_back(dim.getPosition());
+ }
+ positions.push_back(pos);
+ }
+ return positions;
+}
+
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
AffineExpr d0, d1, d2;
@@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
@@ -3836,7 +3880,7 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
-FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
if (parser.parseOptionalKeyword("indexing_maps"))
return ArrayAttr{
nullptr}; // Success in case indexing_maps was not provided.
@@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+SmallVector<AffineMap>
+MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{1, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//
@@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{
utils::IteratorType::parallel, utils::IteratorType::parallel,
@@ -5042,7 +5474,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
-bool areTilesAndTiledDimsAllConstant(OpTy op) {
+static bool areTilesAndTiledDimsAllConstant(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
@@ -5345,11 +5777,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getSourceType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+ // Collect the outer dims corresponding to the tilled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
@@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{1, 2};
+}
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
std::string BatchReduceMatmulOp::getLibraryCallName() {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bdfc8d0..f0c1f44 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -68,12 +70,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
- struct TrivialPatternRewriter : public PatternRewriter {
- public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
- };
- TrivialPatternRewriter rewriter(operation->getContext());
+ PatternRewriter rewriter(operation->getContext());
rewriter.setInsertionPoint(operation);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
@@ -1985,14 +1982,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Convert the padding values to attributes.
SmallVector<Attribute> paddingValues;
- for (auto const &it :
+ for (auto const &[untypedAttr, elementOrTensorType] :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
+ auto attr = dyn_cast<TypedAttr>(untypedAttr);
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
- Type elementType = getElementTypeOrSelf(std::get<1>(it));
+ Type elementType = getElementTypeOrSelf(elementOrTensorType);
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2002,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
- << elementType << ", got " << std::get<0>(it);
+ << elementType << ", got " << untypedAttr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
@@ -2235,8 +2237,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
auto attr = dyn_cast<TypedAttr>(untypedAttr);
Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
// Try to parse string attributes to obtain an attribute of element type.
@@ -3783,8 +3790,15 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+ bool foldTypeExtensionsIntoContract, bool vectorizePadding,
+ bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
+ if (foldTypeExtensionsIntoContract) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::
+ getFoldTypeExtensionsIntoContractAttrName(result.name),
+ builder.getUnitAttr());
+ }
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3875,6 +3889,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
+ if (getFoldTypeExtensionsIntoContract())
+ vector::populateFoldArithExtensionPatterns(patterns);
+
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73..6912da3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -55,8 +55,8 @@ static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
// Skip the batch dimension if present.
// Offset all dimensions accordingly.
SmallVector<int64_t, 3> offsetDims(dims);
- for (size_t i = 0; i < offsetDims.size(); i++)
- offsetDims[i] += batchDimsOffset;
+ for (int64_t &offsetDim : offsetDims)
+ offsetDim += batchDimsOffset;
auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
OpBuilder builder(tileOp);
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
- patterns.getContext(), controlFn);
+ BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
+ controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 70f846e..fb39e186 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
- NamedOpConversions.cpp
+ SimplifyDepthwiseConv.cpp
+ NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index d1eb270..108abe8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}
-// Delinearizes the given composite `index` by the basis specified in `factors`.
-static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
- ArrayRef<int64_t> factors) {
- assert(!factors.empty() && "empty factor list");
- SmallVector<Value> basis;
- for (int64_t f : factors)
- basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
- FailureOr<SmallVector<Value>> multiIndex =
- affine::delinearizeIndex(b, loc, index, basis);
- assert(!failed(multiIndex) && "Failed to linearize img2col index");
- return *multiIndex;
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+ bool useSymbols = true) {
+ AffineExpr oExpr, fExpr;
+ if (useSymbols)
+ bindSymbols(b.getContext(), oExpr, fExpr);
+ else
+ bindDims(b.getContext(), oExpr, fExpr);
+ return AffineExpr(stride * oExpr + fExpr);
}
-// Given indices corresponding to iterators in the output (oIndex) and filter
-// (fIndex) for a convolution, compute the convolved index for the
-// input as `oIndex * stride + fIndex`.
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
- Value fIndex, int64_t stride) {
- AffineExpr oExpr, fExpr;
- bindSymbols(b.getContext(), oExpr, fExpr);
- AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
- return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the corresponding indices of the output and filter matrices
+struct Im2ColToOperandsExprs {
+ AffineExpr fhIndex;
+ AffineExpr fwIndex;
+ AffineExpr icIndex;
+ AffineExpr ohIndex;
+ AffineExpr owIndex;
+};
+
+// Stores the affine expressions to map the iteration space of the im2col matrix
+// to the input matrix indices
+struct Im2ColToInputDimsExprs {
+ AffineExpr bIndex;
+ AffineExpr hIndex;
+ AffineExpr wIndex;
+ AffineExpr cIndex;
+};
+
+/// Construct the affine expressions that map the indices of the im2col matrix
+/// to the corresponding input tensor indices for a 2D convolution with the the
+/// provided strides.
+///
+/// @param exprs Affine expressions for output and filter indices.
+/// @param strides [height, width] stride values for the convolution.
+/// @param rewriter Pattern rewriter.
+/// @return Affine expressions mapping im2col matrix indices to input
+/// offsets.
+static Im2ColToInputDimsExprs
+getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
+ ArrayRef<int64_t> strides, RewriterBase &rewriter) {
+ // maps the iteration space of the im2col matrix to (output_y, filter_y)
+ auto hIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
+ // maps the iteration space of the im2col matrix to (output_x, filter_x)
+ auto wIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
+ // Compute the input indexing map, to map the indices of the im2col matrix to
+ // the original input offsets. Each element of the im2col matrix corresponds
+ // to a pair of (out_element, filter_element). First, we build the expressions
+ // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
+ // then we compose them with the maps that map the im2col matrix elements to
+ // the (out_element, filter_element) pairs.
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+ auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
+ /*useSymbols*/ false);
+ hIndexExpr = hIndexExpr.compose(hIndicesMap);
+ auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
+ /*useSymbols*/ false);
+ wIndexExpr = wIndexExpr.compose(wIndicesMap);
+ auto cIndexExpr = exprs.icIndex;
+ return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}
FailureOr<std::pair<Operation *, Operation *>>
@@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
+
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
- SmallVector<AffineMap, 4> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ // Recover the original iteration indices from the problem/input sizes:
+ // given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
+ ArrayRef<int64_t>{fh * fw, fw, 1});
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.icIndex = kIndicesExprs[0];
+ i2cToOperExprs.fhIndex = kIndicesExprs[1];
+ i2cToOperExprs.fwIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
+ inExprs.hIndex, inExprs.wIndex}},
+ rewriter.getContext())[0];
+ // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
+ SmallVector<AffineMap> img2colIndexingMaps = {
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
- auto icIndex = kIndices[0];
- auto fhIndex = kIndices[1];
- auto fwIndex = kIndices[2];
-
- SmallVector<Value> nIndices = unrollIndex(
- nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = nIndices[0];
- auto owIndex = nIndices[1];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
- SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because the filter does not share the same batch dimension,
@@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
Value reshapedOutput = tensor::CollapseShapeOp::create(
rewriter, loc, reshapedOutputType, output, outputReassocIndices);
+ // Shape of the Toeplitz matrix produced by Im2col.
SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
inputType.getElementType());
@@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Given an index of the im2col matrix, retrieve the corresponding indices of
+ // the output and filter matrices
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ Im2ColToOperandsExprs i2cToOperExprs;
+ i2cToOperExprs.fhIndex = kIndicesExprs[0];
+ i2cToOperExprs.fwIndex = kIndicesExprs[1];
+ i2cToOperExprs.icIndex = kIndicesExprs[2];
+ i2cToOperExprs.ohIndex = mIndicesExprs[0];
+ i2cToOperExprs.owIndex = mIndicesExprs[1];
+
+ // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+ Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
+ i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
+ rewriter);
+ auto inMap =
+ AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
+ inExprs.wIndex, inExprs.cIndex}},
+ rewriter.getContext())[0];
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = linalg::GenericOp::create(
rewriter, loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
- Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
- Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
- extractionIndices);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
});
// Because we didn't transpose the filters we don't actually have a batched
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4..2ff7f46 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -75,7 +75,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
// layout for best compatibility.
Value toBuffer = bufferization::ToBufferOp::create(
b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
- tensorSource, /*readOnly=*/true);
+ tensorSource, /*read_only=*/true);
memref::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
@@ -84,7 +84,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
// layout for best compatibility.
Value toBuffer = bufferization::ToBufferOp::create(
b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
- tensorSource, /*readOnly=*/true);
+ tensorSource, /*read_only=*/true);
linalg::CopyOp::create(b, loc, toBuffer, memrefDest);
} break;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c176..40085a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ private:
ControlPropagationFn controlFn;
};
+// This struct contains infomation about extract_slice dims.
+struct SliceDimInfo {
+ OpFoldResult offset;
+ OpFoldResult sliceSize;
+ OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
+ OpOperand *sliceOperand = nullptr;
+ for (auto operand : genericOp.getDpsInputOperands()) {
+ auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ continue;
+ sliceOperand = operand;
+ break;
+ }
+ if (!sliceOperand) {
+ return failure();
+ }
+ return sliceOperand;
+}
+
+// Return a map of dims that have partial slices on them so that other operands
+// can use this information. Also return a bool mentioning if a reduction dim
+// has a non full slice as that can be used to fold the original extract slice.
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+ llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+ SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+ genericOp.getContext(), producerSliceOp.getSourceType().getShape());
+
+ for (auto [idx, expr] : llvm::enumerate(
+ genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+ // If we have a full slice in a dimension then we dont need to add it to
+ // the partial slice map.
+ if (isConstantIntValue(offsets[idx], 0) &&
+ isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+ continue;
+ }
+ // We only support partial slices of AffineDimExprs so bail-out if thats not
+ // the case.
+ if (!isa<AffineDimExpr>(expr)) {
+ return failure();
+ }
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+ partialSliceDimMap[dimPos] = sliceDimInfo;
+ }
+ // Next check if the dims with partial slice info are used in non
+ // AffineDimExpr in other operands and if they are then bail-out.
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ if (operand == *sliceOperand) {
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+ if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+ if (isa<AffineDimExpr>(expr)) {
+ return false;
+ }
+ WalkResult status = expr.walk([&](AffineExpr expr) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ if (partialSliceDimMap.contains(dimExpr.getPosition())) {
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (status.wasInterrupted()) {
+ return true;
+ }
+ return false;
+ })) {
+ return failure();
+ }
+ }
+ return partialSliceDimMap;
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp,
+ ControlPropagationFn controlFn) {
+ if (genericOp.getNumResults() != 1)
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation through multi-result generic is unsupported.");
+ if (hasGatherSemantics(genericOp))
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation through generic with gather semantics is unsupported.");
+ // Collect the sliced operand, if present.
+ auto maybeSliceOperand = getSliceOperand(genericOp);
+ if (failed(maybeSliceOperand))
+ return failure();
+ OpOperand *sliceOperand = *maybeSliceOperand;
+ unsigned OperandIndex = sliceOperand->getOperandNumber();
+
+ if (!controlFn(sliceOperand))
+ return failure();
+
+ tensor::ExtractSliceOp producerSliceOp =
+ sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+ if (producerSliceOp.getSource().getType().getRank() !=
+ producerSliceOp.getResult().getType().getRank()) {
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "propagation of rank-reducing extract slice is unsupported.");
+ }
+
+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+ if (!areAllConstantIntValue(strides, 1))
+ return rewriter.notifyMatchFailure(
+ genericOp, "propagation of strided extract slice is unsupported.");
+
+ // check if we can support the propagation of this extractSlice
+ // through the generic op and if so return the dimensions that
+
+ auto maybePartialSliceDimMap =
+ getPartialSliceDimInfo(genericOp, sliceOperand);
+
+ if (failed(maybePartialSliceDimMap)) {
+ return failure();
+ }
+
+ auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+ bool hasPartialReductionDimSlice =
+ llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+ int64_t sliceDim = slice.first;
+ return iterators[sliceDim] == utils::IteratorType::reduction;
+ });
+
+ // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
+ Location loc = genericOp->getLoc();
+ AffineExpr dim0, dim1;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+ {v1, v2});
+ };
+
+ MLIRContext *ctx = genericOp.getContext();
+ SmallVector<Value> paddedInputs;
+ for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ if (idx == OperandIndex && !hasPartialReductionDimSlice) {
+ paddedInputs.push_back(producerSliceOp.getSource());
+ continue;
+ }
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+ SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
+ }
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ operandLowPads[idx] = sliceDimInfo.offset;
+ operandHighPads[idx] =
+ sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ }
+ auto paddingValue = ub::PoisonOp::create(
+ rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+ auto paddedOperand = tensor::PadOp::create(
+ rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+ paddingValue, /*nofold=*/false);
+ paddedInputs.push_back(paddedOperand);
+ }
+ AffineMap outputIndexingMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+ auto outputShapeType =
+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+ outputShapeType.getShape(),
+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+ SmallVector<OpFoldResult> newSizes = OutputShape;
+ SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 0));
+ SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+ getAsIndexOpFoldResult(ctx, 1));
+ for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+ if (!isa<AffineDimExpr>(expr)) {
+ continue;
+ }
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+ if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+ continue;
+ }
+ SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+ outputLowPads[idx] = sliceDimInfo.offset;
+ outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+ sliceDimInfo.sliceSize);
+ OutputShape[idx] = sliceDimInfo.outputSize;
+ newSizes[idx] = sliceDimInfo.sliceSize;
+ }
+ Value newPadOutput;
+ auto outputElType =
+ getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+ if (isGenericOutsNotUsed(genericOp)) {
+ newPadOutput =
+ tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+ } else {
+ auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+ newPadOutput = tensor::PadOp::create(
+ rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+ outputHighPads, paddingValue, /*nofold=*/false);
+ }
+
+ auto newGenericOp = linalg::GenericOp::create(
+ rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+ genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+ newGenericOp.getRegion().begin());
+
+ auto extractOp = tensor::ExtractSliceOp::create(
+ rewriter, loc,
+ newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+ outputLowPads, newSizes, newStrides);
+ Value extractRes = extractOp.getResult();
+
+ return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+ : public OpRewritePattern<GenericOp> {
+public:
+ PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ auto genericAndRepl =
+ pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+ if (failed(genericAndRepl))
+ return failure();
+ rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+ return success();
+ }
+
+private:
+ ControlPropagationFn controlFn;
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation) {
+ patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+ patterns.getContext(), controlPackUnPackPropagation);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index bf66ed0..22690da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
auto newResultType = RankedTensorType::get(
newResultShape, padOp.getResultType().getElementType());
- auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
- newHighPad, paddingVal, padOp.getNofold());
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
+ newLowPad, newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
if (options.rankReductionStrategy ==
@@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c523153..baf4083 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,26 @@ namespace mlir {
using namespace mlir;
+static inline bool isScalarLike(Type t) {
+ return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ return !(isa<RankedTensorType>(t) || isScalarLike(t));
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
- auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
- SmallVector<AffineMap, 3> indexingMaps(
- op->getNumResults() + op->getNumOperands(),
- rewriter.getMultiDimIdentityMap(rank));
- SmallVector<utils::IteratorType, 6> iteratorTypes(
+ auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
+ auto rank = resTy.getRank();
+
+ // Maps: identity for tensors (rank > 0), scalar map for scalars.
+ AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
+ /*results=*/{}, rewriter.getContext());
+ AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
+
+ // Match phase.
+ SmallVector<bool> isScalarOperand;
+ isScalarOperand.reserve(op->getNumOperands());
+ for (Type ty : op->getOperandTypes()) {
+ if (isScalarLike(ty))
+ isScalarOperand.push_back(true);
+ else if (auto rt = dyn_cast<RankedTensorType>(ty))
+ isScalarOperand.push_back(false);
+ else
+ return rewriter.notifyMatchFailure(
+ op,
+ "unsupported operand type (expected scalar-like or ranked tensor)");
+ }
+
+ // Create indexing maps.
+ SmallVector<AffineMap> indexingMaps;
+ indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
+
+ for (bool isScalar : isScalarOperand)
+ indexingMaps.push_back(isScalar ? scalarMap : idMap);
+
+ indexingMaps.append(op->getNumResults(), idMap);
+
+ SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
- auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
+ SmallVector<Value> outputs =
+ getOrCreateOperandsMatchingResultTypes(rewriter, op);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/op->getOperands(),
@@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
- auto resultTypes = llvm::to_vector<6>(
+ SmallVector<Type> resultEltTys = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return cast<TensorType>(type).getElementType();
}));
- auto *scalarOp =
+ Operation *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
regionArgs.take_front(op->getNumOperands()),
- resultTypes, op->getAttrs());
+ resultEltTys, op->getAttrs());
linalg::YieldOp::create(builder, loc, scalarOp->getResults());
});
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index fd530f2..9436f1c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
auto clonedForOp = scf::ForOp::create(
rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()),
bvm.lookupOrDefault(forOp.getUpperBound()),
- bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
+ bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Map the induction var, region args and results to the `clonedForOp`.
bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 58986a6..36434cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
scf::ForOp newLoop = scf::ForOp::create(
rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loop.getUnsignedCmp());
// Generate the new yield with the replaced operand.
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
@@ -165,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
Value source = transferRead.getBase();
// Skip view-like Ops and retrive the actual soruce Operation
- while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
- source = srcOp.getViewSource();
+ while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
+ if (viewLike.getViewDest() != source) {
+ break;
+ }
+ source = viewLike.getViewSource();
+ }
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
@@ -177,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
if (!processed.insert(user).second)
continue;
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
- users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
new file mode 100644
index 0000000..f261ccb
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
@@ -0,0 +1,62 @@
+//===- MorphOps.cpp - conversion between named,category and generic ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements conversions between linalg ops:
+// named <--> category (elementwise, contraction, ..) <--> generic.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMORPHOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-morphism"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct LinalgMorphOpsPass
+ : public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {
+
+ using impl::LinalgMorphOpsPassBase<
+ LinalgMorphOpsPass>::LinalgMorphOpsPassBase;
+
+ void runOnOperation() override;
+};
+
+void LinalgMorphOpsPass::runOnOperation() {
+
+ RewritePatternSet patterns(&getContext());
+
+ // Lowering paths (named -> category -> generic)
+ if (namedToCategory) {
+ populateLinalgNamedToElementwisePatterns(patterns);
+ }
+ if (namedToGeneric || categoryToGeneric) {
+ populateLinalgNamedOpsGeneralizationPatterns(patterns);
+ }
+
+ // Lifting paths (named <- category <- generic)
+ if (genericToNamed) {
+ populateLinalgGenericOpsSpecializationPatterns(patterns);
+ }
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
new file mode 100644
index 0000000..00a076b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -0,0 +1,98 @@
+//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting those linalg named ops that are essentially
+// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
+// optimization on `linalg.elementwise` such as folding transpose, broadcast.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-named-to-elementwise"
+
+namespace {
+ElementwiseKind getKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
+ .Case([](SelectOp) { return ElementwiseKind::select; })
+ .Case([](AddOp) { return ElementwiseKind::add; })
+ .Case([](SubOp) { return ElementwiseKind::sub; })
+ .Case([](MulOp) { return ElementwiseKind::mul; })
+ .Case([](DivOp) { return ElementwiseKind::div; })
+ .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
+ .Case([](PowFOp) { return ElementwiseKind::powf; })
+ .Case([](ExpOp) { return ElementwiseKind::exp; })
+ .Case([](LogOp) { return ElementwiseKind::log; })
+ .Case([](AbsOp) { return ElementwiseKind::abs; })
+ .Case([](CeilOp) { return ElementwiseKind::ceil; })
+ .Case([](FloorOp) { return ElementwiseKind::floor; })
+ .Case([](NegFOp) { return ElementwiseKind::negf; })
+ .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
+ .Case([](RoundOp) { return ElementwiseKind::round; })
+ .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
+ .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
+ .Case([](SquareOp) { return ElementwiseKind::square; })
+ .Case([](TanhOp) { return ElementwiseKind::tanh; })
+ .Case([](ErfOp) { return ElementwiseKind::erf; })
+ .Default([&](Operation *op) {
+ llvm_unreachable("unhandled case in named to elementwise");
+ return ElementwiseKind::sub;
+ });
+}
+
+template <typename NamedOpTy>
+struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
+ using OpRewritePattern<NamedOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(NamedOpTy op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<NamedAttribute> attrs;
+ auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
+ attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
+ attrs.push_back(
+ rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
+
+ rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
+ op.getDpsInits(), attrs);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateLinalgNamedToElementwisePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
+ patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e62523..8942670 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
Value paddingValue;
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
- auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
- complexTy, complexAttr);
- } else {
- paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
- cast<TypedAttr>(paddingValueAttr));
+ if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
+ }
+ } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+ paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ getElementTypeOrSelf(v.getType()));
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+ paddingValue =
+ arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
}
+ assert(paddingValue && "failed to create value from padding attribute");
// Pad the operand to the bounding box defined by `paddedShape`.
SmallVector<int64_t> tensorShape;
@@ -257,11 +263,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface>
-linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- PadSizeComputationFunction computePaddingSizeFun) {
+FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
+ RewriterBase &rewriter, TilingInterface opToPad,
+ const PadTilingInterfaceOptions &constOptions,
+ SmallVector<tensor::PadOp> &padOps,
+ const PadSizeComputationFunction &computePaddingSizeFun) {
LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
Location loc = opToPad.getLoc();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index a2bd9d9..27ccf3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -21,7 +21,7 @@
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
-#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
+#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
@@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp
}
};
-struct LinalgNamedOpConversionPass
- : public impl::LinalgNamedOpConversionPassBase<
- LinalgNamedOpConversionPass> {
- using impl::LinalgNamedOpConversionPassBase<
- LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
+struct SimplifyDepthwiseConvPass
+ : public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> {
+ using impl::SimplifyDepthwiseConvPassBase<
+ SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- populateLinalgNamedOpConversionPatterns(patterns);
+ populateSimplifyDepthwiseConvPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
-void mlir::linalg::populateLinalgNamedOpConversionPatterns(
+void mlir::linalg::populateSimplifyDepthwiseConvPatterns(
RewritePatternSet &patterns) {
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
patterns.getContext());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6..35ba4f15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bb725f2..e9a8b25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -29,6 +29,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <utility>
@@ -38,9 +39,6 @@
using namespace mlir;
using namespace mlir::linalg;
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
-
//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//
@@ -91,11 +89,11 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
}
return true;
}
+#endif // NDEBUG
static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}
-#endif // NDEBUG
/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
@@ -276,23 +274,18 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
highs, paddingValue, /*nofold=*/false);
- LLVM_DEBUG(
- DBGSNL(); DBGSNL();
- DBGS() << "insertPositions: "
- << llvm::interleaved(packingMetadata.insertPositions);
- DBGSNL(); DBGS() << "outerPositions: "
- << llvm::interleaved(packingMetadata.outerPositions);
- DBGSNL(); DBGS() << "packedShape: "
- << llvm::interleaved(packedTensorType.getShape());
- DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
- << llvm::interleaved(packedToStripMinedShapePerm);
- DBGSNL();
- DBGS() << "reassociations: "
- << llvm::interleaved(llvm::map_range(
- packingMetadata.reassociations, stringifyReassocIndices));
- DBGSNL();
- DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
- DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+ LDBG() << "insertPositions: "
+ << llvm::interleaved(packingMetadata.insertPositions);
+ LDBG() << "outerPositions: "
+ << llvm::interleaved(packingMetadata.outerPositions);
+ LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
+ LDBG() << "packedToStripMinedShapePerm: "
+ << llvm::interleaved(packedToStripMinedShapePerm);
+ LDBG() << "reassociations: "
+ << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
+ stringifyReassocIndices));
+ LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
+ LDBG() << "collapsed type: " << collapsed;
if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
// Pack ops which operate as simple pads may not produce legal
@@ -317,7 +310,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
/*offsets=*/zeros, sizes, /*strides=*/ones);
- LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+ LDBG() << "insert_slice op: " << insertSliceOp;
rewriter.replaceOp(packOp, insertSliceOp->getResults());
@@ -339,10 +332,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
auto transposeOp = linalg::TransposeOp::create(
rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "reshape op: " << reshapeOp; DBGSNL();
- DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
- DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+ LDBG() << "reshape op: " << reshapeOp;
+ LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
+ LDBG() << "transpose op: " << transposeOp;
// 7. Replace packOp by transposeOp.
rewriter.replaceOp(packOp, transposeOp->getResults());
@@ -410,21 +402,16 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
packedToStripMinedShapePerm);
- LLVM_DEBUG(
- DBGSNL(); DBGSNL();
- DBGS() << "insertPositions: "
- << llvm::interleaved(packingMetadata.insertPositions);
- DBGSNL(); DBGS() << "packedShape: "
- << llvm::interleaved(packedTensorType.getShape());
- DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
- << llvm::interleaved(packedToStripMinedShapePerm);
- DBGSNL();
- DBGS() << "reassociations: "
- << llvm::interleaved(llvm::map_range(
- packingMetadata.reassociations, stringifyReassocIndices));
- DBGSNL();
- DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
- DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+ LDBG() << "insertPositions: "
+ << llvm::interleaved(packingMetadata.insertPositions);
+ LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
+ LDBG() << "packedToStripMinedShapePerm: "
+ << llvm::interleaved(packedToStripMinedShapePerm);
+ LDBG() << "reassociations: "
+ << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
+ stringifyReassocIndices));
+ LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
+ LDBG() << "collapsed type: " << collapsedType;
// 4. Collapse from the stripMinedShape to the padded result.
auto reshapeOp = tensor::CollapseShapeOp::create(
@@ -486,10 +473,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
- LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
- << "maps: " << llvm::interleaved(indexingMaps) << "\n"
- << "iterators: " << llvm::interleaved(iteratorTypes)
- << "\n");
+ LDBG() << "Start packing: " << linalgOp;
+ LDBG() << "maps: " << llvm::interleaved(indexingMaps);
+ LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
SmallVector<linalg::PackOp> packOps;
SmallVector<linalg::UnPackOp> unPackOps;
@@ -511,14 +497,11 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
- LLVM_DEBUG(
- DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
- << "\n"
- << "maps: " << llvm::interleaved(indexingMaps) << "\n"
- << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
- << "packedDimForEachOperand: "
- << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
- << "\n");
+ LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
+ LDBG() << "maps: " << llvm::interleaved(indexingMaps);
+ LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
+ LDBG() << "packedDimForEachOperand: "
+ << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
}
// Step 2. Propagate packing to all LinalgOp operands.
@@ -534,10 +517,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
SmallVector<OpFoldResult> innerPackSizes =
listOfPackedOperandsDim.extractPackSizesForOperand(pos);
- LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
- << "innerPos: " << llvm::interleaved(innerPos) << "\n"
- << "innerPackSizes: "
- << llvm::interleaved(innerPackSizes) << "\n");
+ LDBG() << "operand: " << operand;
+ LDBG() << "innerPos: " << llvm::interleaved(innerPos);
+ LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
if (innerPackSizes.empty()) {
inputsAndInits.push_back(operand);
continue;
@@ -776,8 +758,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
- LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
- << numLoops << "\nin: " << linalgOp << "\n");
+ LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
+ << " in: " << linalgOp;
return rewriter.notifyMatchFailure(
linalgOp, "need 3+ loops to find a matmul to pack");
}
@@ -801,8 +783,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
FailureOr<ContractionDimensions> maybeDimensions =
inferContractionDims(linalgOp);
if (failed(maybeDimensions)) {
- LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
- << "\n");
+ LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer matmul iterators");
}
@@ -814,10 +795,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// to plug a heuristic.
int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
kPos = maybeDimensions->k.back();
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "Start packing generic op greedily with (m@" << mPos
- << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
- << "\n";);
+ LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
+ << nPos << ", k@" << kPos << "): " << linalgOp;
// 2.a. Rewrite as a generic.
auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
@@ -833,14 +812,14 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// not change the indexings of any operand.
SmallVector<int64_t> permutation =
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
- LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
+ LDBG() << "perm: " << llvm::interleaved(permutation);
// Sign .. unsigned pollution.
SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
FailureOr<GenericOp> interchangeResult =
interchangeGenericOp(rewriter, genericOp, unsignedPerm);
assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
genericOp = *interchangeResult;
- LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
+ LDBG() << "Generalized Op to pack: " << genericOp;
// At this point, the op iterators are normalized to {leading, k, m, n}.
// The layouts induced by packing will always be:
@@ -862,12 +841,11 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// Add leading zeros to match numLoops, we only pack the last 3 dimensions
// post interchange.
- LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
- << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
- << "loopRanges: "
- << llvm::interleaved(llvm::map_range(
- loopRanges, [](Range r) { return r.size; }))
- << "\n");
+ LDBG() << "paddedSizesNextMultipleOf: "
+ << llvm::interleaved(paddedSizesNextMultipleOf);
+ LDBG() << "loopRanges: "
+ << llvm::interleaved(
+ llvm::map_range(loopRanges, [](Range r) { return r.size; }));
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
rewriter.getIndexAttr(0));
for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
@@ -883,8 +861,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
{loopRanges[adjustedPackedSizes.size()].size,
rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
}
- LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
- << llvm::interleaved(adjustedPackedSizes) << "\n");
+ LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time. One would still need to check that
@@ -1214,9 +1191,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
- LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
- << "perm: " << llvm::interleaved(srcPermForTranspose)
- << "\n");
+ LDBG() << "Pack permutation: " << packOp;
+ LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a2a4335..2650488 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::MatmulTransposeAOp::create(
+ newMatmulOp = MatmulTransposeAOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
- newMatmulOp = linalg::MatmulTransposeBOp::create(
+ newMatmulOp = MatmulTransposeBOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
@@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
+ newMatmulOp = BatchMatmulTransposeAOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
- newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
+ newMatmulOp = BatchMatmulTransposeBOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cf65e67..406f05c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2563,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization";
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG()
<< "Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported";
@@ -2604,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) ||
hasReductionIterator(linalgOp));
}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index e1c0c24..d37a056 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
- ExpandPatterns.cpp
+ ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4a40a30..cd68039 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -13,14 +13,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXPANDOPSPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
/// Create a float constant.
static Value createFloatConst(Location loc, Type type, APFloat value,
OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
return success();
}
-void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
- patterns.add(convertCtlzOp);
-}
-
-void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertSinhOp);
-}
-
-void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertCoshOp);
-}
-
-void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanOp);
-}
-
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanhOp);
-}
-
-void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAsinhOp);
-}
-
-void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
- patterns.add(convertAcoshOp);
-}
-
-void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertAtanhOp);
-}
-
-void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
- patterns.add(convertFmaFOp);
-}
-
-void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
- patterns.add(convertCeilOp);
-}
-
-void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
- patterns.add(convertExp2fOp);
-}
-
-void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
- patterns.add(convertPowfOp);
-}
-
-void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowIOp);
-}
-
-void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundOp);
+// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
+static LogicalResult convertClampfOp(math::ClampFOp op,
+ PatternRewriter &rewriter) {
+ auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
+ op.getMin(), op.getFastmath());
+ rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
+ op.getFastmath());
+ return success();
}
-void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
- patterns.add(convertRoundEvenOp);
+void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
+ ArrayRef<StringRef> opMnemonics) {
+ auto filter = [&](StringRef name) {
+ // This should be a static assert and `consume_front` take a twine, but none
+ // is currently possible. TODO: augment `StringRef::consume_front` and make
+ // `getDialectNamespace` use `std::string_view`.
+ assert("math" == MathDialect::getDialectNamespace());
+ name.consume_front("math.");
+ return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
+ };
+ if (filter(CountLeadingZerosOp::getOperationName()))
+ patterns.add(convertCtlzOp);
+ if (filter(SinhOp::getOperationName()))
+ patterns.add(convertSinhOp);
+ if (filter(CoshOp::getOperationName()))
+ patterns.add(convertCoshOp);
+ if (filter(TanOp::getOperationName()))
+ patterns.add(convertTanOp);
+ if (filter(TanhOp::getOperationName()))
+ patterns.add(convertTanhOp);
+ if (filter(AsinhOp::getOperationName()))
+ patterns.add(convertAsinhOp);
+ if (filter(AcoshOp::getOperationName()))
+ patterns.add(convertAcoshOp);
+ if (filter(AtanhOp::getOperationName()))
+ patterns.add(convertAtanhOp);
+ if (filter(FmaOp::getOperationName()))
+ patterns.add(convertFmaFOp);
+ if (filter(CeilOp::getOperationName()))
+ patterns.add(convertCeilOp);
+ if (filter(Exp2Op::getOperationName()))
+ patterns.add(convertExp2fOp);
+ if (filter(PowFOp::getOperationName()))
+ patterns.add(convertPowfOp);
+ if (filter(FPowIOp::getOperationName()))
+ patterns.add(convertFPowIOp);
+ if (filter(RoundOp::getOperationName()))
+ patterns.add(convertRoundOp);
+ if (filter(RoundEvenOp::getOperationName()))
+ patterns.add(convertRoundEvenOp);
+ if (filter(RsqrtOp::getOperationName()))
+ patterns.add(convertRsqrtOp);
+ if (filter(ClampFOp::getOperationName()))
+ patterns.add(convertClampfOp);
}
-void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
- patterns.add(convertRsqrtOp);
-}
+//===----------------------------------------------------------------------===//
+// MathExpandOpsPass pass
+//===----------------------------------------------------------------------===//
+namespace {
+struct MathExpandOpsPass final
+ : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
+ using MathExpandOpsPassBase::MathExpandOpsPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ SmallVector<StringRef> mnemonics =
+ llvm::to_vector_of<StringRef>(opMnemonics);
+ math::populateExpansionPatterns(patterns, mnemonics);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c..b59d73d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
index bbb269b..1939195 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -21,9 +21,9 @@ namespace {
struct ReallocOpInterface
: public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface,
ReallocOp> {
- void
- populateDependencies(Operation *op,
- RegisterDependenciesFn registerDependenciesFn) const {
+ void populateDependencies(
+ Operation *op,
+ const RegisterDependenciesFn &registerDependenciesFn) const {
auto reallocOp = cast<ReallocOp>(op);
// memref.realloc may return the source operand.
registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 9771bd2..d35566a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
- if (!viewLikeOp)
+ if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 5d3cec4..860384f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) {
/// propagate the type change and erase old subview ops.
static void replaceUsesAndPropagateType(RewriterBase &rewriter,
Operation *oldOp, Value val) {
- SmallVector<Operation *> opsToDelete;
- SmallVector<OpOperand *> operandsToReplace;
-
- // Save the operand to replace / delete later (avoid iterator invalidation).
- // TODO: can we use an early_inc iterator?
- for (OpOperand &use : oldOp->getUses()) {
- // Non-subview ops will be replaced by `val`.
- auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
- if (!subviewUse) {
- operandsToReplace.push_back(&use);
+ // Iterate with early_inc to erase current user inside the loop.
+ for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
+ Operation *user = use.getOwner();
+ if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
+ // `subview(old_op)` is replaced by a new `subview(val)`.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(subviewUse);
+ MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
+ subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
+ subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+ subviewUse.getStaticStrides());
+ Value newSubview = memref::SubViewOp::create(
+ rewriter, subviewUse->getLoc(), newType, val,
+ subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+ subviewUse.getMixedStrides());
+
+ // Ouch recursion ... is this really necessary?
+ replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
+
+ // Safe to erase.
+ rewriter.eraseOp(subviewUse);
continue;
}
-
- // `subview(old_op)` is replaced by a new `subview(val)`.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(subviewUse);
- MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
- subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
- subviewUse.getStaticStrides());
- Value newSubview = memref::SubViewOp::create(
- rewriter, subviewUse->getLoc(), newType, val,
- subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
- subviewUse.getMixedStrides());
-
- // Ouch recursion ... is this really necessary?
- replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
- opsToDelete.push_back(use.getOwner());
+ // Non-subview: replace with new value.
+ rewriter.startOpModification(user);
+ use.set(val);
+ rewriter.finalizeOpModification(user);
}
-
- // Perform late replacement.
- // TODO: can we use an early_inc iterator?
- for (OpOperand *operand : operandsToReplace) {
- Operation *op = operand->getOwner();
- rewriter.startOpModification(op);
- operand->set(val);
- rewriter.finalizeOpModification(op);
- }
-
- // Perform late op erasure.
- // TODO: can we use an early_inc iterator?
- for (Operation *op : opsToDelete)
- rewriter.eraseOp(op);
}
// Transformation to do multi-buffering/array expansion to remove dependencies
@@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
- // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
- // handle dealloc uses separately..
+ // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need
+ // to handle dealloc uses separately..
for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
if (!deallocOp)
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 5af46a4..3de9c38 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
MemrefValue skipViewLikeOps(MemrefValue source) {
while (auto op = source.getDefiningOp()) {
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
- source = cast<MemrefValue>(viewLike.getViewSource());
- continue;
+ if (source == viewLike.getViewDest()) {
+ source = cast<MemrefValue>(viewLike.getViewSource());
+ continue;
+ }
}
return source;
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 34c95e3..8474244 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -422,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
<< descMemref << " != " << dstMemref;
}
+ int lastDimBytes =
+ descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
+ if (lastDimBytes % 16 != 0) {
+ return op->emitError() << "the bytes in the last dimension of the tensor "
+ "map must be a multiple of 16";
+ }
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 485bb73..ded4c7a 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -173,9 +173,7 @@ void OpenACCDialect::initialize() {
//===----------------------------------------------------------------------===//
static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
- if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
- return true;
- return false;
+ return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
}
static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
@@ -1390,6 +1388,36 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::ParallelOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
+void acc::ParallelOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
static ParseResult parseNumGangs(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -2041,6 +2069,36 @@ void acc::SerialOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::SerialOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
+void acc::SerialOp::addReduction(MLIRContext *context,
+ mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// KernelsOp
//===----------------------------------------------------------------------===//
@@ -3059,6 +3117,20 @@ void acc::LoopOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
+ mlir::acc::ReductionRecipeOp recipe) {
+ getReductionOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getReductionRecipesAttr())
+ llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767..6e43f28 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3874,6 +3874,159 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
+ return getInTypeAttr().getValue();
+}
+
+/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
+/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+/// attr-dict-without-keyword
+static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ auto &builder = parser.getBuilder();
+ bool hasOperands = false;
+ std::int32_t typeparamsSize = 0;
+
+ // Parse device number as a new operand
+ mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+ mlir::Type deviceType;
+ if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+ return mlir::failure();
+ if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+ return mlir::failure();
+ if (parser.parseComma())
+ return mlir::failure();
+
+ mlir::Type intype;
+ if (parser.parseType(intype))
+ return mlir::failure();
+ result.addAttribute("in_type", mlir::TypeAttr::get(intype));
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::Type> typeVec;
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeVec) || parser.parseRParen())
+ return mlir::failure();
+ typeparamsSize = operands.size();
+ hasOperands = true;
+ }
+ std::int32_t shapeSize = 0;
+ if (!parser.parseOptionalComma()) {
+ // parse size to scale by, vector of n dimensions of type index
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
+ return mlir::failure();
+ shapeSize = operands.size() - typeparamsSize;
+ auto idxTy = builder.getIndexType();
+ for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
+ typeVec.push_back(idxTy);
+ hasOperands = true;
+ }
+ if (hasOperands &&
+ parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
+ result.operands))
+ return mlir::failure();
+
+ mlir::Type restype = builder.getIntegerType(64);
+ if (!restype) {
+ parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
+ return mlir::failure();
+ }
+ llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
+ result.addAttribute("operandSegmentSizes",
+ builder.getDenseI32ArrayAttr(segmentSizes));
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.addTypeToList(restype, result.types))
+ return mlir::failure();
+ return mlir::success();
+}
+
+mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseTargetAllocMemOp(parser, result);
+}
+
+void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+ p << " ";
+ p.printOperand(getDevice());
+ p << " : ";
+ p << getDevice().getType();
+ p << ", ";
+ p << getInType();
+ if (!getTypeparams().empty()) {
+ p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
+ }
+ for (auto sh : getShape()) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"in_type", "operandSegmentSizes"});
+}
+
+llvm::LogicalResult omp::TargetAllocMemOp::verify() {
+ mlir::Type outType = getType();
+ if (!mlir::dyn_cast<IntegerType>(outType))
+ return emitOpError("must be a integer type");
+ return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// WorkdistributeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkdistributeOp::verify() {
+ // Check that region exists and is not empty
+ Region &region = getRegion();
+ if (region.empty())
+ return emitOpError("region cannot be empty");
+ // Verify single entry point.
+ Block &entryBlock = region.front();
+ if (entryBlock.empty())
+ return emitOpError("region must contain a structured block");
+ // Verify single exit point.
+ bool hasTerminator = false;
+ for (Block &block : region) {
+ if (isa<TerminatorOp>(block.back())) {
+ if (hasTerminator) {
+ return emitOpError("region must have exactly one terminator");
+ }
+ hasTerminator = true;
+ }
+ }
+ if (!hasTerminator) {
+ return emitOpError("region must be terminated with omp.terminator");
+ }
+ auto walkResult = region.walk([&](Operation *op) -> WalkResult {
+ // No implicit barrier at end
+ if (isa<BarrierOp>(op)) {
+ return emitOpError(
+ "explicit barriers are not allowed in workdistribute region");
+ }
+ // Check for invalid nested constructs
+ if (isa<ParallelOp>(op)) {
+ return emitOpError(
+ "nested parallel constructs not allowed in workdistribute");
+ }
+ if (isa<TeamsOp>(op)) {
+ return emitOpError(
+ "nested teams constructs not allowed in workdistribute");
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return failure();
+
+ Operation *parentOp = (*this)->getParentOp();
+ if (!llvm::dyn_cast<TeamsOp>(parentOp))
+ return emitOpError("workdistribute must be nested under teams");
+ return success();
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index 497468b..bd1e655 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -1,3 +1,22 @@
+set(LLVM_OPTIONAL_SOURCES
+ MemorySpaceInterfaces.cpp
+ PtrAttrs.cpp
+ PtrTypes.cpp
+ PtrDialect.cpp
+)
+
+add_mlir_dialect_library(
+ MLIRPtrMemorySpaceInterfaces
+ MemorySpaceInterfaces.cpp
+
+ DEPENDS
+ MLIRPtrOpsEnumsGen
+ MLIRPtrMemorySpaceInterfacesIncGen
+ LINK_LIBS
+ PUBLIC
+ MLIRIR
+)
+
add_mlir_dialect_library(
MLIRPtrDialect
PtrAttrs.cpp
@@ -15,4 +34,5 @@ add_mlir_dialect_library(
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
MLIRViewLikeInterface
+ MLIRPtrMemorySpaceInterfaces
)
diff --git a/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp
new file mode 100644
index 0000000..059e67f
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp
@@ -0,0 +1,15 @@
+//===-- MemorySpaceInterfaces.cpp - ptr memory space interfaces -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ptr dialect memory space interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
+
+#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
index 772d25d..ac3bcd6 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
@@ -22,26 +22,30 @@ constexpr const static unsigned kBitsInByte = 8;
//===----------------------------------------------------------------------===//
bool GenericSpaceAttr::isValidLoad(
- Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidStore(
- Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidAtomicOp(
ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
- IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const {
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool GenericSpaceAttr::isValidAtomicXchg(
Type type, ptr::AtomicOrdering successOrdering,
- ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
+ ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c5ec0ca..d5976b9 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -85,6 +85,124 @@ LogicalResult FromPtrOp::verify() {
}
//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the attributes and the type of atomic memory access operations.
+template <typename OpTy>
+static LogicalResult
+verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
+ if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
+ if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
+ return memOp.emitOpError("unsupported ordering '")
+ << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
+ if (!memOp.getAlignment())
+ return memOp.emitOpError("expected alignment for atomic access");
+ return success();
+ }
+ if (memOp.getSyncscope()) {
+ return memOp.emitOpError(
+ "expected syncscope to be null for non-atomic access");
+ }
+ return success();
+}
+
+/// Verifies that the alignment attribute is a power of 2 if present.
+static LogicalResult
+verifyAlignment(std::optional<int64_t> alignment,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (!alignment)
+ return success();
+ if (alignment.value() <= 0)
+ return emitError() << "alignment must be positive";
+ if (!llvm::isPowerOf2_64(alignment.value()))
+ return emitError() << "alignment must be a power of 2";
+ return success();
+}
+
+void LoadOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+ // Volatile operations can have target-specific read-write effects on
+ // memory besides the one referred to by the pointer operand.
+ // Similarly, atomic operations that are monotonic or stricter cause
+ // synchronization that from a language point-of-view, are arbitrary
+ // read-writes into memory.
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
+ getOrdering() != AtomicOrdering::unordered)) {
+ effects.emplace_back(MemoryEffects::Write::get());
+ effects.emplace_back(MemoryEffects::Read::get());
+ }
+}
+
+LogicalResult LoadOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(),
+ &dataLayout, emitDiag))
+ return failure();
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+ return verifyAtomicMemOp(*this,
+ {AtomicOrdering::release, AtomicOrdering::acq_rel});
+}
+
+void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
+ Value addr, unsigned alignment, bool isVolatile,
+ bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
+ AtomicOrdering ordering, StringRef syncscope) {
+ build(builder, state, type, addr,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
+ syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+void StoreOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+ // Volatile operations can have target-specific read-write effects on
+ // memory besides the one referred to by the pointer operand.
+ // Similarly, atomic operations that are monotonic or stricter cause
+ // synchronization that from a language point-of-view, are arbitrary
+ // read-writes into memory.
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
+ getOrdering() != AtomicOrdering::unordered)) {
+ effects.emplace_back(MemoryEffects::Write::get());
+ effects.emplace_back(MemoryEffects::Read::get());
+ }
+}
+
+LogicalResult StoreOp::verify() {
+ auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+ MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
+ DataLayout dataLayout = DataLayout::closest(*this);
+ if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(),
+ &dataLayout, emitDiag))
+ return failure();
+ if (failed(verifyAlignment(getAlignment(), emitDiag)))
+ return failure();
+ return verifyAtomicMemOp(*this,
+ {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
+}
+
+void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value addr, unsigned alignment, bool isVolatile,
+ bool isNonTemporal, bool isInvariantGroup,
+ AtomicOrdering ordering, StringRef syncscope) {
+ build(builder, state, value, addr,
+ alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+ isVolatile, isNonTemporal, isInvariantGroup, ordering,
+ syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
+}
+
+//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -152,10 +270,6 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
-#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
-
-#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
-
#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
#define GET_TYPEDEF_CLASSES
diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
index 825d119..deb7109 100644
--- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt
@@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRQuantTransforms
StripFuncQuantTypes.cpp
ADDITIONAL_HEADER_DIRS
- {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms
DEPENDS
MLIRQuantTransformsIncGen
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0262a1b..84f9777 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -157,8 +157,7 @@ void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
-
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
}
LogicalResult ExecuteRegionOp::verify() {
@@ -318,9 +317,12 @@ void ConditionOp::getSuccessorRegions(
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, ValueRange initArgs,
- BodyBuilderFn bodyBuilder) {
+ BodyBuilderFn bodyBuilder, bool unsignedCmp) {
OpBuilder::InsertionGuard guard(builder);
+ if (unsignedCmp)
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
result.addOperands({lb, ub, step});
result.addOperands(initArgs);
for (Value v : initArgs)
@@ -450,6 +452,9 @@ static void printInitializationList(OpAsmPrinter &p,
}
void ForOp::print(OpAsmPrinter &p) {
+ if (getUnsignedCmp())
+ p << " unsigned";
+
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();
@@ -462,7 +467,8 @@ void ForOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
}
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -472,6 +478,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument inductionVariable;
OpAsmParser::UnresolvedOperand lb, ub, step;
+ if (succeeded(parser.parseOptionalKeyword("unsigned")))
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
+
// Parse the induction variable followed by '='.
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
// Parse loop bounds.
@@ -562,7 +572,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
inits.append(newInitOperands.begin(), newInitOperands.end());
scf::ForOp newLoop = scf::ForOp::create(
rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
- [](OpBuilder &, Location, Value, ValueRange) {});
+ [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
// Generate the new yield values and append them to the scf.yield operation.
@@ -806,7 +816,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
// 2. Create the new forOp shell.
scf::ForOp newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterOperands);
+ forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
@@ -931,7 +942,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
scf::ForOp newForOp =
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
- forOp.getUpperBound(), forOp.getStep(), newIterArgs);
+ forOp.getUpperBound(), forOp.getStep(), newIterArgs,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
@@ -989,12 +1001,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference between two AffineValueMap is
/// dynamic.
-static std::optional<int64_t> computeConstDiff(Value l, Value u) {
+static std::optional<APInt> computeConstDiff(Value l, Value u) {
IntegerAttr clb, cub;
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
llvm::APInt lbValue = clb.getValue();
llvm::APInt ubValue = cub.getValue();
- return (ubValue - lbValue).getSExtValue();
+ return ubValue - lbValue;
}
// Else a simple pattern match for x + c or c + x
@@ -1003,7 +1015,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) {
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
matchPattern(
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff.getSExtValue();
+ return diff;
return std::nullopt;
}
@@ -1022,13 +1034,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}
- std::optional<int64_t> diff =
+ std::optional<APInt> diff =
computeConstDiff(op.getLowerBound(), op.getUpperBound());
if (!diff)
return failure();
// If the loop is known to have 0 iterations, remove it.
- if (*diff <= 0) {
+ bool zeroOrLessIterations =
+ diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
+ if (zeroOrLessIterations) {
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
@@ -3384,9 +3398,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
@@ -4222,14 +4235,15 @@ LogicalResult scf::IndexSwitchOp::verify() {
<< "see yield operation here";
}
for (auto [idx, result, operand] :
- llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
- yield.getOperandTypes())) {
- if (result == operand)
+ llvm::enumerate(getResultTypes(), yield.getOperands())) {
+ if (!operand)
+ return yield.emitOpError() << "operand " << idx << " is null\n";
+ if (result == operand.getType())
continue;
return (emitOpError("expected result #")
<< idx << " of each region to be " << result)
.attachNote(yield.getLoc())
- << name << " returns " << operand << " here";
+ << name << " returns " << operand.getType() << " here";
}
return success();
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index f8799c5..fb179e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -769,7 +769,8 @@ struct ForOpInterface
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), castedInitArgs);
+ forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block *loopBody = newForOp.getBody();
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index bee7780..ae52af5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
- auto cmpOp = arith::CmpIOp::create(
- rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
- beforeBlock->getArgument(0), forOp.getUpperBound());
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
+ beforeBlock->getArgument(0),
+ forOp.getUpperBound());
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 1130538..7e7fba4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
bool *modifiedIR) {
if (modifiedIR)
*modifiedIR = false;
+
+ // TODO: Add support for unsigned loops.
+ if (forOp.getUnsignedCmp())
+ return failure();
+
LoopPipelinerInternal pipeliner;
if (!pipeliner.initializeLoopInfo(forOp, options))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 4752c08..f1203b2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override {
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Do not peel already peeled loops.
if (forOp->hasAttr(kPeeledLoopLabel))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 694cd85..4ea8321 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -269,10 +269,10 @@ namespace {
struct ParallelLoopFusion
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
void runOnOperation() override {
- auto &AA = getAnalysis<AliasAnalysis>();
+ auto &aa = getAnalysis<AliasAnalysis>();
auto mayAlias = [&](Value val1, Value val2) -> bool {
- return !AA.alias(val1, val2).isNo();
+ return !aa.alias(val1, val2).isNo();
};
getOperation()->walk([&](Operation *child) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 1b07b77..072bc50 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ public:
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
@@ -116,7 +116,8 @@ public:
llvm::getSingleElement(adaptor.getLowerBound()),
llvm::getSingleElement(adaptor.getUpperBound()),
llvm::getSingleElement(adaptor.getStep()),
- flattenValues(adaptor.getInitArgs()));
+ flattenValues(adaptor.getInitArgs()),
+ /*bodyBuilder=*/nullptr, op.getUnsignedCmp());
// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
@@ -126,7 +127,6 @@ public:
// Inline the type converted region from the original operation.
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
-
return newOp;
}
};
@@ -225,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(
void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
- target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes());
- });
+ target.addDynamicallyLegalOp<ForOp, IfOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
return true;
- return typeConverter.isLegal(op.getOperandTypes());
+ return typeConverter.isLegal(op.getOperands());
});
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c0e47ee..834c021 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
inits.append(newInitOperands.begin(), newInitOperands.end());
auto newLoop = scf::ForOp::create(
rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
- loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loopOp.getUnsignedCmp());
// Move the loop body to the new op.
Block *loopBody = loopOp.getBody();
@@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest(
auto newLoop = scf::ForOp::create(
rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
forLoop.getUpperBound(), forLoop.getStep(), newInits,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
+ forLoop.getUnsignedCmp());
// Merge the body of the new loop with the body of the old loops.
SmallVector<Value> sourceBlockArgs;
@@ -1914,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}
-/// Check that the loop is perfectly nested.
-/// The loops are expected to be ordered from outer most to inner most.
-/// For example:
-/// ```
-/// %0 = scf.for()
-/// %1 = scf.for()
-/// %2 = scf.for()
-/// %3 = ...
-/// yield %3
-/// yield %2
-/// yield %1
-/// ```
-/// Here loops should be [%0, %1].
-static bool
-isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
- assert(!loops.empty() && "unexpected empty loop nest");
- if (loops.size() == 1) {
- return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
- }
- for (auto [outerLoop, innerLoop] :
- llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
- auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
- auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
- if (!outerFor || !innerFor) {
- return false;
- }
- auto outerBBArgs = outerFor.getRegionIterArgs();
- auto innerIterArgs = innerFor.getInitArgs();
- if (outerBBArgs.size() != innerIterArgs.size()) {
- return false;
- }
-
- for (auto [outerBBArg, innerIterArg] :
- llvm::zip_equal(outerBBArgs, innerIterArgs)) {
- if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
- innerIterArg != outerBBArg) {
- return false;
- }
- }
-
- ValueRange outerYields =
- cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
- ValueRange innerResults = innerFor.getResults();
- if (outerYields.size() != innerResults.size()) {
- return false;
- }
- for (auto [outerYield, innerResult] :
- llvm::zip_equal(outerYields, innerResults)) {
- if (!llvm::hasSingleElement(innerResult.getUses()) ||
- outerYield != innerResult) {
- return false;
- }
- }
- }
- return true;
-}
-
/// Fetch the untiled consumer of the outermost scf.for's result which is
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
/// makes the following assumptions :
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5731795..684dff8 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl(
static Loops stripmineSink(scf::ForOp forOp, Value factor,
ArrayRef<scf::ForOp> targets) {
+ assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
auto originalStep = forOp.getStep();
auto iv = forOp.getInductionVar();
@@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
Loops innerLoops;
for (auto t : targets) {
+ assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
+
// Save information for splicing ops out of t when done
auto begin = t.getBody()->begin();
auto nOps = t.getBody()->getOperations().size();
@@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
+ assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
+ "incompatible signedness");
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
@@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
rewriter.setInsertionPointAfter(source);
scf::ForOp fusedLoop = scf::ForOp::create(
rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
+ source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
+ source.getUnsignedCmp());
// Map original induction variables and operands to those of the fused loop.
IRMapping mapping;
@@ -1506,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
rewriter.replaceOp(forallOp, normalizedForallOp);
return normalizedForallOp;
}
+
+bool mlir::isPerfectlyNestedForLoops(
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected empty loop nest");
+ if (loops.size() == 1)
+ return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+ if (!outerFor || !innerFor)
+ return false;
+ auto outerBBArgs = outerFor.getRegionIterArgs();
+ auto innerIterArgs = innerFor.getInitArgs();
+ if (outerBBArgs.size() != innerIterArgs.size())
+ return false;
+
+ for (auto [outerBBArg, innerIterArg] :
+ llvm::zip_equal(outerBBArgs, innerIterArgs)) {
+ if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+ innerIterArg != outerBBArg)
+ return false;
+ }
+
+ ValueRange outerYields =
+ cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+ ValueRange innerResults = innerFor.getResults();
+ if (outerYields.size() != innerResults.size())
+ return false;
+ for (auto [outerYield, innerResult] :
+ llvm::zip_equal(outerYields, innerResults)) {
+ if (!llvm::hasSingleElement(innerResult.getUses()) ||
+ outerYield != innerResult)
+ return false;
+ }
+ }
+ return true;
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index ddb3426..369b953 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1322,7 +1322,7 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
}
TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
- : shape(std::move(shape)), elementType(std::move(elementType)) {}
+ : shape(shape), elementType(elementType) {}
ArrayRef<int64_t> shape;
Type elementType;
diff --git a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
index d4e7618..7a05dfe 100644
--- a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
@@ -513,8 +513,9 @@ LogicalResult shard::detail::defaultAddShardingAnnotations(
}
#ifndef NDEBUG
-static bool isValueCompatibleWithFullReplicationSharding(Value value,
- Sharding sharding) {
+static bool
+isValueCompatibleWithFullReplicationSharding(Value value,
+ const Sharding &sharding) {
if (isa<RankedTensorType>(value.getType())) {
return isFullReplication(sharding);
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 3e3d476..5dc61a2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -477,10 +477,10 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
return targetShard;
}
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
- Sharding sourceSharding, Sharding targetSharding,
- TypedValue<ShapedType> sourceUnshardedValue,
- TypedValue<ShapedType> sourceShard) {
+static TypedValue<ShapedType>
+reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding,
+ Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
isFullReplication(targetSharding))) {
@@ -535,7 +535,7 @@ using UnshardedToShardedValueMap = DenseMap<Value, Value>;
// Get the types of block arguments for an partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
-SmallVector<Type>
+static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
SymbolTableCollection &symbolTableCollection) {
SmallVector<Type> res;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 56b435c..9694a40 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -231,7 +231,9 @@ ParseResult DimLvlMapParser::parseLvlSpecList() {
const auto loc = parser.getCurrentLocation();
const auto res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::Paren,
- [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
+ [this, requireLvlVarBinding]() -> ParseResult {
+ return parseLvlSpec(requireLvlVarBinding);
+ },
" in level-specifier list");
FAILURE_IF_FAILED(res)
const auto specLvlRank = lvlSpecs.size();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 9e2e6ab..a1711a6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -156,13 +156,14 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
return pair1 <= pair2 ? sm1 : sm2;
}
-bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
+static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id,
+ StringRef name) {
const auto &var = env.access(id);
return (var.getName() == name && var.getID() == id);
}
-bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
- VarKind vk) {
+static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id,
+ llvm::SMLoc loc, VarKind vk) {
const auto &var = env.access(id);
return var.getKind() == vk;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 3b97786..dabbea1 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createLowerAffinePass());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
@@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createConvertComplexToLibm());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertComplexToLLVMPass());
- pm.addPass(
- createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertFuncToLLVMPass());
- pm.addPass(createArithToLLVMConversionPass());
- pm.addPass(createConvertControlFlowToLLVMPass());
// Finalize GPU code generation.
if (gpuCodegen) {
@@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
}
- // Convert poison values.
- pm.addPass(createUBToLLVMConversionPass());
+ // Convert to LLVM.
+ pm.addPass(createConvertToLLVMPass());
// Ensure all casts are realized.
pm.addPass(createReconcileUnrealizedCastsPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 3b4140e..ae7eef2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1219,8 +1219,9 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
-LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
- uint64_t ny, PatternRewriter &rewriter) {
+static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys,
+ AffineMap xPerm, uint64_t ny,
+ PatternRewriter &rewriter) {
Location loc = op.getLoc();
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3..0e88d31d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ public:
{tensor, lvlCoords, values, filled, added, count},
EmitCInterface::On);
Operation *parent = getTop(op);
+ rewriter.setInsertionPointAfter(parent);
rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
- rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 4464450..febec6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
VectorType vtp = vectorType(vl, init.getType());
Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
forOp.getRegionIterArg(0), init, vtp);
- forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
- forOp.getUpperBound(), step, vinit);
+ forOpNew =
+ scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
+ forOp.getUpperBound(), step, vinit,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
forOpNew->setAttr(
LoopEmitter::getLoopEmitterLoopAttrName(),
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
@@ -605,8 +607,8 @@ public:
ForOpRewriter(MLIRContext *context, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32)
- : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
- enableSIMDIndex32} {}
+ : OpRewritePattern(context),
+ vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d4b112..68584ec 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3200,20 +3200,6 @@ void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "padded");
}
-// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
-// supports optional types.
-void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
- Type typeToInfer, Type typeToInferFrom) {}
-
-ParseResult
-parseInferType(OpAsmParser &parser,
- std::optional<OpAsmParser::UnresolvedOperand> optOperand,
- Type &typeToInfer, Type typeToInferFrom) {
- if (optOperand)
- typeToInfer = typeToInferFrom;
- return success();
-}
-
LogicalResult PadOp::verify() {
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
@@ -4059,7 +4045,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
-bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
// 1. InsertSliceOp has its own logic about folding tensor.cast ops.
// 2. Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1..dfce835 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
PatternRewriter &rewriter) const override {
auto expandShapeOp =
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "tensor.extract_slice source not produced by expand_shape");
+ }
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
- if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
- rewriter)
- .failed())
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
+ offsets, sizes, strides)))
return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
- // referring to the state before applying the pattern are named with the
- // prefix "expanded", and ones referring to the state after applying the
- // pattern are named with the prefix "collapsed".
- SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> expandedShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- // Helper variables and function for accumulating the size values.
- Location loc = expandShapeOp->getLoc();
- AffineExpr d0, d1, d2;
- bindDims(rewriter.getContext(), d0, d1, d2);
- // Multiply two integers.
- auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
- auto mulMap = AffineMap::get(2, 0, {d0 * d1});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2});
- };
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank of
- // ReassociationIndices.size(). In the loop a single offset, size, and
- // stride value is computed per reassociation group.
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
- collapsedStrides;
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- // collapsedSize will hold the size of the single dim that represents the
- // reassociation group in the non expanded tensor.
- OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
- // The reassocGroupSizes and reassocGroupOffsets are used to create an
- // affine.linearize_index op to linearize the single offset value required
- // for this reassociation group.
- SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
- for (long expandedDim : indices) {
- // reassocGroupSizes and reassocGroupOffsets can be obtained directly
- // from the expanded state, but the collapsed size requires calculation
- // as it did not previously exist.
- reassocGroupSizes.push_back(expandedShape[expandedDim]);
- reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
- collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
- }
-
- SmallVector<Value> offsetVals =
- llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
- });
- OpFoldResult collapsedOffset =
- affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
- reassocGroupSizes,
- /*disjoint=*/true)
- .getResult();
- collapsedOffsets.push_back(collapsedOffset);
- collapsedSizes.push_back(collapsedSize);
-
- // Only unit stride is supported.
- collapsedStrides.push_back(rewriter.getIndexAttr(1));
- }
-
// The shape of the result can be obtained from the sizes passed in.
- SmallVector<Value> dynDims;
- SmallVector<int64_t> shape;
- dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
- RankedTensorType resultType = RankedTensorType::get(
- shape, expandShapeOp.getResultType().getElementType());
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ RankedTensorType resultType = sliceOp.getResultType();
// Create a new ExtractSliceOp and ExpandShapeOp.
+ Location loc = sliceOp.getLoc();
Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
- collapsedStrides);
+ rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
sliceOp, resultType, newSliceOp,
expandShapeOp.getReassociationIndices(), expandedSizes);
return success();
}
-
- // Helper function to check if all the required conditions for the
- // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
- // met.
- LogicalResult
- checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
- tensor::ExpandShapeOp expandShapeOp,
- PatternRewriter &rewriter) const {
-
- if (!expandShapeOp) {
- return rewriter.notifyMatchFailure(
- sliceOp, "tensor.extract_slice source not produced by expand_shape");
- }
-
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
-
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- sizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
-
- SmallVector<OpFoldResult> outputShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
- isZeroOffsetAndFullSize =
- [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isZeroInteger(offset))
- return false;
- FailureOr<bool> maybeEqual =
- ValueBoundsConstraintSet::areEqual(sliceSize, size);
- return llvm::succeeded(maybeEqual) && maybeEqual.value();
- };
-
- // Check that the slice is contiguous within each reassociation group.
- // The slice is contiguous only if after the first dimension where a non
- // unit slice is taken, the slice size on all subsequent dimensions of the
- // group is equal to the entire size of the dimension.
- // Examples of contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
- // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
- // Examples of non contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
- // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- int64_t i = 0;
- int64_t e = indices.size();
- // Find the first expanded dim after the first dim with non-unit extracted
- // size.
- for (; i < e; ++i) {
- if (!isOneInteger(sizes[indices[i]])) {
- // +1 to skip the first non-unit size dim.
- i++;
- break;
- }
- }
-
- // Verify that all subsequent dimensions extract the full size of the
- // source tensor.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
- outputShape[expandedDim])) {
- return rewriter.notifyMatchFailure(
- sliceOp, "Not a contiguous slice of the expanded tensor.");
- }
- }
- }
-
- return success();
- }
};
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,281 @@ struct BubbleUpCollapseShapeThroughExtractSlice
"tensor.extract_slice source not produced by tensor.collapse_shape");
}
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getExpandedExtractSliceInfo(
+ rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
+ collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
+ return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.collapse_shape, so variables (i.e. inputs for
- // ExtractSliceOp) referring to the state before applying the pattern are
- // named with the prefix "collapsed", and ones referring to the state after
- // applying the pattern are named with the prefix "expanded".
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- collapsedSizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
+ Value newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+ sizes, strides);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ sliceOp, sliceOp.getResultType(), newSliceOp,
+ collapseShapeOp.getReassociationIndices());
- ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
- SmallVector<ReassociationIndices, 4> reassociationIndices =
- collapseShapeOp.getReassociationIndices();
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank
- // equal to the rank of the src of the collapse_shape. In each iteration of
- // the loop, the offsets and sizes will be computed per reassociation group.
- SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
- SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
- rewriter.getIndexAttr(1));
-
- for (auto [collapsedSize, collapsedOffset, reassocIndices] :
- llvm::zip_equal(collapsedSizes, collapsedOffsets,
- collapseShapeOp.getReassociationIndices())) {
- // CASE #1 - size and/or offset are dynamic.
- // In this case, the slice can be represented as a contiguous slice only
- // if there is a single dimension in the reassociation group that has a
- // size not equal to 1.
- if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
- int nonUnitSizeCount = 0;
- for (int64_t expandedShapeIdx : reassocIndices) {
- if (srcShape[expandedShapeIdx] != 1) {
- nonUnitSizeCount++;
- expandedSizes.push_back(collapsedSize);
- expandedOffsets.push_back(collapsedOffset);
- continue;
- }
-
- expandedSizes.push_back(rewriter.getIndexAttr(1));
- expandedOffsets.push_back(rewriter.getIndexAttr(0));
- }
+ return success();
+ }
+};
- if (nonUnitSizeCount != 1) {
- return rewriter.notifyMatchFailure(
- sliceOp,
- "unsupported: slice cannot be verified to be contiguous");
- }
- continue;
- }
+} // namespace
- // CASE #2 = size and offset are static.
- // Verify that the slice can be represented as a contiguous slice of the
- // src of the collapse_shape.
- // Checking this is done on order of most internal dimensions first,
- // so traversal is done in reverse order of the reassociation group.
- // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
- // ...,An] then we first find the size and offset for n...k+1 then for k
- // and then for k-1...0.
-
- // currentCollapsedsize and currentCollapsedOffset are initialized with
- // the original collapsed size and offset and divided by the expanded
- // shape size in each dimension as we go along the reassociation group.
- // In essence we are spreading the original collapsed size and offset over
- // the various expanded slice dimensions.
- // The variables are used both to check the validity of the slice and to
- // compute the expanded sizes and offsets.
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
- int64_t currentCollapsedOffset =
- getConstantIntValue(collapsedOffset).value();
-
- SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
- ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
- reassocIndices.rend());
- int64_t idx = 0;
- int64_t reassocGroupSize = reassocIndices.size();
-
- // First handle the trailing dimensions where the slice size should be
- // equal to the tensor shape and the offset should be 0 (n...k+1).
- for (; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
- if (currentCollapsedsize < expandedShapeSize)
- break;
-
- // We need to make sure that the slice size can be set to the shape size
- // and the offset to 0.
- if ((currentCollapsedsize % expandedShapeSize) != 0 ||
- (currentCollapsedOffset % expandedShapeSize) != 0) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: cannot be extracted as a contiguous slice "
- "of the src of the collapse_shape");
- }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+ OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+ return failure();
+ }
- currentCollapsedsize /= expandedShapeSize;
- currentCollapsedOffset /= expandedShapeSize;
+ auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+ OpFoldResult sliceSize, int64_t inputDim) {
+ if (!isZeroInteger(offset))
+ return false;
+ ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // Check that the slice is contiguous within each reassociation group.
+ // The slice is contiguous only if after the first dimension where a non
+ // unit slice is taken, the slice size on all subsequent dimensions of the
+ // group is equal to the entire size of the dimension.
+ // Examples of contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+ // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+ // Examples of non contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+ // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+ for (const ReassociationIndices &indices : reassociation) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isOneInteger(sizes[indices[i]])) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
}
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ expandedDim)) {
+ return failure();
+ }
+ }
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+ // referring to the state before applying the pattern are named with the
+ // prefix "expanded", and ones referring to the state after applying the
+ // pattern are named with the prefix "collapsed".
+ Location loc = sliceOp.getLoc();
+ SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> expandedShape =
+ getMixedSizes(b, loc, sliceOp.getSource());
+
+ // Helper variables and function for accumulating the size values.
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
+ };
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has has a rank of
+ // ReassociationIndices.size(). In the loop a single offset, size, and
+ // stride value is computed per reassociation group.
+ for (const ReassociationIndices &indices : reassociation) {
+ // collapsedSize will hold the size of the single dim that represents the
+ // reassociation group in the non expanded tensor.
+ OpFoldResult collapsedSize = b.getIndexAttr(1);
+ // The reassocGroupSizes and reassocGroupOffsets are used to create an
+ // affine.linearize_index op to linearize the single offset value required
+ // for this reassociation group.
+ SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+ for (long expandedDim : indices) {
+ // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+ // from the expanded state, but the collapsed size requires calculation
+ // as it did not previously exist.
+ reassocGroupSizes.push_back(expandedShape[expandedDim]);
+ reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+ collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+ }
+
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ });
+ OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
+ b, loc, offsetVals, reassocGroupSizes,
+ /*disjoint=*/true)
+ .getResult();
+ collapsedOffsets.push_back(collapsedOffset);
+ collapsedSizes.push_back(collapsedSize);
+
+ // Only unit stride is supported.
+ collapsedStrides.push_back(b.getIndexAttr(1));
+ }
+ return success();
+}
+
+LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
+ OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
+ // ExtractSliceOp) referring to the state before applying the pattern are
+ // named with the prefix "collapsed", and ones referring to the state after
+ // applying the pattern are named with the prefix "expanded".
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+ collapsedSizes.size()) {
+ return failure();
+ }
- // Now handle the first dim where slicing occurs on (k).
- if (idx < reassocGroupSize) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- // We need to make sure that the slice size in this dim + offset will
- // not exceed the shape size.
- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: slice cannot be extracted as a contiguous "
- "slice of the src of the collapse_shape");
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has has a rank
+ // equal to the rank of the src of the collapse_shape. In each iteration of
+ // the loop, the offsets and sizes will be computed per reassociation group.
+ expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+ llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
+ // CASE #1 - size and/or offset are dynamic.
+ // In this case, the slice can be represented as a contiguous slice only
+ // if there is a single dimension in the reassociation group that has a
+ // size not equal to 1.
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ int nonUnitSizeCount = 0;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (expandedShape[expandedShapeIdx] != 1) {
+ nonUnitSizeCount++;
+ expandedSizes.push_back(collapsedSize);
+ expandedOffsets.push_back(collapsedOffset);
+ continue;
}
- groupExpandedSizes.push_back(
- rewriter.getIndexAttr(currentCollapsedsize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+ expandedSizes.push_back(b.getIndexAttr(1));
+ expandedOffsets.push_back(b.getIndexAttr(0));
+ }
- currentCollapsedOffset /= expandedShapeSize;
+ if (nonUnitSizeCount != 1) {
+ return failure();
}
+ continue;
+ }
- // Now handle the leading dimensions where the slice size is equal to 1
- // (k-1...0).
- // The size for these dimensions must be 1 because of how we constructed
- // the slice size of the expanded shape. We spread the original collapsed
- // size over the expanded shape sizes until we reached dimension k where
- // the remaining size was smaller than the expanded shape size, and spread
- // the remaining size on it. So, now we are left with only 1s.
- for (idx++; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- currentCollapsedOffset /= expandedShapeSize;
+ // CASE #2 = size and offset are static.
+ // Verify that the slice can be represented as a contiguous slice of the
+ // src of the collapse_shape.
+ // Checking this is done on order of most internal dimensions first,
+ // so traversal is done in reverse order of the reassociation group.
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+ // ...,An] then we first find the size and offset for n...k+1 then for k
+ // and then for k-1...0.
+
+ // currentCollapsedsize and currentCollapsedOffset are initialized with
+ // the original collapsed size and offset and divided by the expanded
+ // shape size in each dimension as we go along the reassociation group.
+ // In essence we are spreading the original collapsed size and offset over
+ // the various expanded slice dimensions.
+ // The variables are used both to check the validity of the slice and to
+ // compute the expanded sizes and offsets.
+ int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+ int64_t currentCollapsedOffset =
+ getConstantIntValue(collapsedOffset).value();
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+ if (currentCollapsedsize < expandedShapeSize)
+ break;
+
+ // We need to make sure that the slice size can be set to the shape size
+ // and the offset to 0.
+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+ (currentCollapsedOffset % expandedShapeSize) != 0) {
+ return failure();
}
- expandedSizes.append(groupExpandedSizes.rbegin(),
- groupExpandedSizes.rend());
- expandedOffsets.append(groupExpandedOffsets.rbegin(),
- groupExpandedOffsets.rend());
+ groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(0));
+
+ currentCollapsedsize /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
- Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
- expandedOffsets, expandedSizes, expandedStrides);
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- sliceOp, sliceOp.getResultType(), newSliceOp,
- collapseShapeOp.getReassociationIndices());
+ // Now handle the first dim where slicing occurs on (k).
+ if (idx < reassocGroupSize) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ // We need to make sure that the slice size in this dim + offset will
+ // not exceed the shape size.
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
+ return failure();
+ }
+ groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
- return success();
+ // Now handle the leading dimensions where the slice size is equal to 1
+ // (k-1...0).
+ // The size for these dimensions must be 1 because of how we constructed
+ // the slice size of the expanded shape. We spread the original collapsed
+ // size over the expanded shape sizes until we reached dimension k where
+ // the remaining size was smaller than the expanded shape size, and spread
+ // the remaining size on it. So, now we are left with only 1s.
+ for (idx++; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ groupExpandedSizes.push_back(b.getIndexAttr(1));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
+ expandedSizes.append(groupExpandedSizes.rbegin(),
+ groupExpandedSizes.rend());
+ expandedOffsets.append(groupExpandedOffsets.rbegin(),
+ groupExpandedOffsets.rend());
}
-};
-
-} // namespace
+ return success();
+}
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba388..8d63646 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -122,8 +122,9 @@ struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
const APFloat lowestVal =
APFloat::getLargest(padConstVal.getSemantics(), true);
return padConstVal == lowestVal;
- } else if (auto padConstIntAttr =
- mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+ }
+ if (auto padConstIntAttr =
+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
const APInt padConstVal = *padConstIntAttr.begin();
const unsigned int bitWidth = padConstVal.getBitWidth();
const APInt lowestVal =
@@ -555,7 +556,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
- if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ if (opNanMode == NanPropagationMode::IGNORE &&
+ clampNanMode == NanPropagationMode::PROPAGATE)
return failure();
auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +638,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}
+ auto newMode = (opNanMode != clampNanMode)
+ ? tosa::NanPropagationMode::IGNORE
+ : opNanMode;
+
+ auto newModeAttr =
+ NanPropagationModeAttr::get(rewriter.getContext(), newMode);
+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
- rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
- : opNanMode));
+ newModeAttr);
return success();
}
};
@@ -1120,13 +1128,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
if (rhsTy == resultTy) {
- if (isSplatZero(resultETy, lhsAttr))
+ if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
+ // constant values can only be resized if resulting type is static
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
- if (isSplatZero(resultETy, rhsAttr))
+ if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3cafb19..bd7aee5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -270,6 +270,244 @@ void mlir::tosa::printVariableOpTypeOrInitialValue(
}
}
+namespace {
+
+// parse attributes with special handling for tosa enum attributes
+template <typename EnumType>
+ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
+ NamedAttrList &outAttrs) {
+ llvm::StringRef name;
+ if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
+ return failure();
+
+ // special handling: rounding_mode accepts a *bare* RoundingMode enum
+ // keyword.
+ llvm::StringRef kw;
+ if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
+ if (name == "rounding_mode" &&
+ succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeRoundingMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid rounding_mode value: " << kw;
+ auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+ // special handling: mode accepts a *bare* ResizeMode enum keyword.
+ if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
+ if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeResizeMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid resize mode value: " << kw;
+ auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+ // special handling: nan_mode accepts a *bare* NanPropagationMode enum
+ // keyword.
+ if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
+ if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeNanPropagationMode(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid nan_mode value: " << kw;
+ auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
+ // Default path: parse any normal attribute literal, including fully qualified
+ // enum keyword
+ Attribute attr;
+ return parser.parseAttribute(attr, name, outAttrs);
+}
+
+template <typename EnumType>
+ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
+ // parse operands
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> operands;
+ if (parser.parseCommaSeparatedList(
+ [&]() { return parser.parseOperand(operands.emplace_back()); }))
+ return failure();
+
+ // Parse { attr-dict } with special handling for enum bare token
+ NamedAttrList attrs;
+ if (succeeded(parser.parseOptionalLBrace()) &&
+ failed(parser.parseOptionalRBrace())) {
+ do {
+ if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
+ return failure();
+ } while (succeeded(parser.parseOptionalComma()));
+ if (parser.parseRBrace())
+ return failure();
+ }
+
+ FunctionType fnTy;
+ if (parser.parseColonType(fnTy))
+ return failure();
+
+ // Resolve operands and types
+ if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+
+ result.addTypes(fnTy.getResult(0));
+ result.addAttributes(attrs);
+
+ return success();
+}
+
+void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
+ parser << namedAttr.getName().strref() << " = ";
+ auto attr = namedAttr.getValue();
+ if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
+ parser << roundingModeAttr.getValue();
+ } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
+ parser << resizeModeAttr.getValue();
+ } else if (auto nanPropagationModeAttr =
+ dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
+ parser << nanPropagationModeAttr.getValue();
+ } else {
+ parser.printAttribute(attr);
+ }
+}
+
+// print with special handling for default valued NanPropagationMode attribute
+void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
+ parser << " ";
+ parser.printOperands(op->getOperands());
+
+ NamedAttrList toPrint(op->getAttrs());
+ // remove default NanPropagate attribute
+ const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
+ for (auto attr : op->getAttrs()) {
+ if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
+ if (nanAttr.getValue() == kDefaultNanValue) {
+ // elide from toPrint
+ toPrint.erase(attr.getName());
+ break;
+ }
+ }
+ }
+
+ if (!toPrint.empty()) {
+ parser << " {";
+ llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
+ printNamedAttr(parser, namedAttr);
+ });
+ parser << "}";
+ }
+
+ parser << " : ";
+ parser.printFunctionalType(op);
+}
+
+// print with special handling for enums: RoundingMode, ResizeMode
+void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
+ parser << " ";
+ parser.printOperands(op->getOperands());
+
+ if (!op->getAttrs().empty()) {
+ parser << " {";
+ llvm::interleaveComma(op->getAttrs(), parser,
+ [&](const NamedAttribute namedAttr) {
+ printNamedAttr(parser, namedAttr);
+ });
+ parser << "}";
+ }
+
+ parser << " : ";
+ parser.printFunctionalType(op);
+}
+
+} // namespace
+
+ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
+}
+
+void RescaleOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
+}
+
+void ApplyScaleOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
+}
+
+void ResizeOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ArgMaxOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ClampOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaximumOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MinimumOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ReduceMaxOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
+ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void ReduceMinOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534..790bbf7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,14 +508,15 @@ private:
bool attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
- if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+ if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
return false;
- } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
- !targetEnv.allows(Extension::inexactround)) {
+ }
+ if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
+ !targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
@@ -1122,7 +1123,7 @@ bool checkErrorIfRescale(Operation *op) {
}
// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
- if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+ if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) {
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
return false;
}
@@ -1307,7 +1308,8 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
- } else if (auto intTy = dyn_cast<IntegerType>(type)) {
+ }
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 9266a63..48df1a0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -37,16 +37,13 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
-
#define DEBUG_TYPE_MATCHER "transform-matcher"
-#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
-#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
using namespace mlir;
@@ -182,8 +179,7 @@ transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
if (result.isSilenceableFailure()) {
- LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
- << "\n");
+ LDBG() << "alternative failed: " << result.getMessage();
failed = true;
break;
}
@@ -1155,12 +1151,10 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
std::optional<DiagnosedSilenceableFailure> maybeFailure;
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
- DEBUG_MATCHER({
- DBGS_MATCHER() << "matching ";
- op->print(llvm::dbgs(),
- OpPrintingFlags().assumeVerified().skipRegions());
- llvm::dbgs() << " @" << op << "\n";
- });
+ LDBG(1, DEBUG_TYPE_MATCHER)
+ << "matching "
+ << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
+ << " @" << op;
// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
@@ -1172,8 +1166,8 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
- << " failed: " << diag.getMessage());
+ LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage();
return WalkResult::advance();
}
@@ -1304,12 +1298,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
if (!getRestrictRoot() && op == root)
return WalkResult::advance();
- DEBUG_MATCHER({
- DBGS_MATCHER() << "matching ";
- op->print(llvm::dbgs(),
- OpPrintingFlags().assumeVerified().skipRegions());
- llvm::dbgs() << " @" << op << "\n";
- });
+ LDBG(1, DEBUG_TYPE_MATCHER)
+ << "matching "
+ << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
+ << " @" << op;
firstMatchArgument.clear();
firstMatchArgument.push_back(op);
@@ -1322,8 +1314,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
- DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
- << " failed: " << diag.getMessage());
+ LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage();
continue;
}
@@ -2173,10 +2165,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
::std::optional<::mlir::Operation *> maybeCurrent,
transform::TransformResults &results, transform::TransformState &state) {
if (!maybeCurrent.has_value()) {
- DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
+ LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp success";
return DiagnosedSilenceableFailure::success();
}
- DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
+ LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp failure";
return emitSilenceableError() << "operation is not empty";
}
diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp
index d666390..773eb13 100644
--- a/mlir/lib/Dialect/Transform/IR/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -90,7 +91,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
//
// Rename private symbols in both ops in order to resolve conflicts that can
// be resolved that way.
- LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ LDBG() << "renaming private symbols to resolve conflicts:";
// TODO: Do we *actually* need to test in both directions?
for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
@@ -102,7 +103,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
if (!symbolOp)
continue;
StringAttr name = symbolOp.getNameAttr();
- LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+ LDBG() << " found @" << name.getValue();
// Check if there is a colliding op in the other module.
auto collidingOp =
@@ -110,7 +111,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
if (!collidingOp)
continue;
- LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+ LDBG() << " collision found for @" << name.getValue();
// Collisions are fine if both opt are functions and can be merged.
if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
@@ -119,13 +120,12 @@ transform::detail::mergeSymbolsInto(Operation *target,
funcOp && collidingFuncOp) {
if (canMergeInto(funcOp, collidingFuncOp) ||
canMergeInto(collidingFuncOp, funcOp)) {
- LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
- "will be merged\n");
+ LDBG() << " but both ops are functions and will be merged";
continue;
}
// If they can't be merged, proceed like any other collision.
- LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+ LDBG() << " and both ops are function definitions";
}
// Collision can be resolved by renaming if one of the ops is private.
@@ -133,7 +133,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
- LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+ LDBG() << ", renaming";
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
if (failed(maybeNewName)) {
@@ -142,8 +142,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "attempted renaming due to collision with this op";
return diag;
}
- LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
- << "\n");
+ LDBG() << " renamed to @" << maybeNewName->getValue();
return InFlightDiagnostic();
};
@@ -161,7 +160,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
return diag;
continue;
}
- LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ LDBG() << ", emitting error";
InFlightDiagnostic diag = symbolOp.emitError()
<< "doubly defined symbol @" << name.getValue();
diag.attachNote(collidingOp->getLoc()) << "previously defined here";
@@ -179,7 +178,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
// Step 2:
//
// Move all ops from `other` into target and merge public symbols.
- LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+ LDBG() << "moving all symbols into target";
{
SmallVector<SymbolOpInterface> opsToMove;
for (Operation &op : other->getRegion(0).front()) {
@@ -193,13 +192,13 @@ transform::detail::mergeSymbolsInto(Operation *target,
targetSymbolTable.lookup(op.getNameAttr()));
// Move op even if we get a collision.
- LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+ LDBG() << " moving @" << op.getName();
op->moveBefore(&target->getRegion(0).front(),
target->getRegion(0).front().end());
// If there is no collision, we are done.
if (!collidingOp) {
- LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+ LDBG() << " without collision";
continue;
}
@@ -217,9 +216,9 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
assert(canMergeInto(funcOp, collidingFuncOp));
- LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
- << collidingFuncOp.getLoc() << ":\n"
- << collidingFuncOp << "\n");
+ LDBG() << " with collision, trying to keep op at "
+ << collidingFuncOp.getLoc() << ":\n"
+ << collidingFuncOp;
// Update symbol table. This works with or without the previous `swap`.
targetSymbolTable.remove(funcOp);
@@ -239,6 +238,6 @@ transform::detail::mergeSymbolsInto(Operation *target,
return target->emitError()
<< "failed to verify target op after merging symbols";
- LLVM_DEBUG(DBGS() << "done merging ops\n");
+ LDBG() << "done merging ops";
return InFlightDiagnostic();
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 14a4fdf..4f4620a 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -312,7 +312,7 @@ LogicalResult transform::TransformState::setParams(Value value,
}
template <typename Mapping, typename Key, typename Mapped>
-void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
+static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
auto it = mapping.find(key);
if (it == mapping.end())
return;
@@ -771,7 +771,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
}
template <typename T>
-DiagnosedSilenceableFailure
+static DiagnosedSilenceableFailure
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
transform::TransformOpInterface transform,
unsigned operandNumber) {
diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
index 41955c8..3ced1a6 100644
--- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
@@ -100,12 +100,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
PatternApplicator applicator(it->second);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
- struct TrivialPatternRewriter : public PatternRewriter {
- public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
- };
- TrivialPatternRewriter rewriter(root->getContext());
+ PatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index e6ef028..34385d7 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
if (!ubConstant)
return std::nullopt;
std::optional<int64_t> stepConstant = getConstantIntValue(step);
- if (!stepConstant)
+ if (!stepConstant || *stepConstant == 0)
return std::nullopt;
return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a450056..9b2a455 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
return foldToElementsFromElements(*this, results);
}
+LogicalResult
+ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+ ToElementsOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto vecType = cast<VectorType>(adaptor.getSource().getType());
+ Type elType = vecType.getElementType();
+ inferredReturnTypes.append(vecType.getNumElements(), elType);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -2456,8 +2466,12 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
return {};
+ // DenseElementsAttr only supports int/index/float/complex types.
auto destVecType = fromElementsOp.getDest().getType();
auto destEltType = destVecType.getElementType();
+ if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
+ return {};
+
// Constant attributes might have a different type than the return type.
// Convert them before creating the dense elements attribute.
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
@@ -2768,8 +2782,8 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
- if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
- getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+ if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
+ srcType == getElementTypeOrSelf(dstVectorType))
return BroadcastableToResult::Success;
// From now on, only vectors broadcast.
VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
@@ -2841,9 +2855,47 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ // Trailing dimensions should be the same if shape_cast only alters the
+ // leading dimensions.
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+ if (!llvm::equal(srcShape.take_back(numTrailingDims),
+ shapecastShape.take_back(numTrailingDims)))
+ return failure();
+
+ assert(all_of(srcShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ all_of(shapecastShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ "ill-formed shape_cast");
+
+ broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getSourceType() == getResultVectorType())
return getSource();
+ if (succeeded(foldBroadcastOfShapeCast(*this)))
+ return getResult();
+
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
@@ -3238,6 +3290,18 @@ LogicalResult InsertOp::verify() {
return success();
}
+// Calculate the linearized position of the continuous chunk of elements to
+// insert, based on the shape of the value to insert and the positions to insert
+// at.
+static int64_t calculateInsertPosition(VectorType destTy,
+ ArrayRef<int64_t> positions) {
+ llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+ assert(positions.size() <= completePositions.size() &&
+ "positions size must be less than or equal to destTy rank");
+ copy(positions, completePositions.begin());
+ return linearize(completePositions, computeStrides(destTy.getShape()));
+}
+
namespace {
// If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3275,6 +3339,132 @@ public:
return success();
}
};
+
+/// Pattern to optimize a chain of insertions.
+///
+/// This pattern identifies chains of vector.insert operations that:
+/// 1. Only insert values at static positions.
+/// 2. Completely initialize all elements in the resulting vector.
+/// 3. All intermediate insert operations have only one use.
+///
+/// When these conditions are met, the entire chain can be replaced with a
+/// single vector.from_elements operation.
+///
+/// To keep this pattern simple, and avoid spending too much time on matching
+/// fragmented insert chains, this pattern only considers the last insert op in
+/// the chain.
+///
+/// Example transformation:
+/// %poison = ub.poison : vector<2xi32>
+/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+/// ->
+/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
+class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+
+ VectorType destTy = op.getDestVectorType();
+ if (destTy.isScalable())
+ return failure();
+ // Ensure this is the trailing vector.insert op in a chain of inserts.
+ for (Operation *user : op.getResult().getUsers())
+ if (auto insertOp = dyn_cast<InsertOp>(user))
+ if (insertOp.getDest() == op.getResult())
+ return failure();
+
+ InsertOp currentOp = op;
+ SmallVector<InsertOp> chainInsertOps;
+ while (currentOp) {
+ // Check cond 1: Dynamic position is not supported.
+ if (currentOp.hasDynamicPosition())
+ return failure();
+
+ chainInsertOps.push_back(currentOp);
+ currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
+ // Check cond 3: Intermediate inserts have only one use to avoid an
+ // explosion of vectors.
+ if (currentOp && !currentOp->hasOneUse())
+ return failure();
+ }
+
+ int64_t vectorSize = destTy.getNumElements();
+ int64_t initializedCount = 0;
+ SmallVector<bool> initializedDestIdxs(vectorSize, false);
+ SmallVector<int64_t> pendingInsertPos;
+ SmallVector<int64_t> pendingInsertSize;
+ SmallVector<Value> pendingInsertValues;
+
+ for (auto insertOp : chainInsertOps) {
+ // This pattern can do nothing with poison index.
+ if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
+ return failure();
+
+ // Calculate the linearized position for inserting elements.
+ int64_t insertBeginPosition =
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
+
+ // The valueToStore operand may be a vector or a scalar. Need to handle
+ // both cases.
+ int64_t insertSize = 1;
+ if (auto srcVectorType =
+ llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
+ insertSize = srcVectorType.getNumElements();
+
+ assert(insertBeginPosition + insertSize <= vectorSize &&
+ "insert would overflow the vector");
+
+ for (auto index : llvm::seq<int64_t>(insertBeginPosition,
+ insertBeginPosition + insertSize)) {
+ if (initializedDestIdxs[index])
+ continue;
+ initializedDestIdxs[index] = true;
+ ++initializedCount;
+ }
+
+ // Defer the creation of ops before we can make sure the pattern can
+ // succeed.
+ pendingInsertPos.push_back(insertBeginPosition);
+ pendingInsertSize.push_back(insertSize);
+ pendingInsertValues.push_back(insertOp.getValueToStore());
+
+ if (initializedCount == vectorSize)
+ break;
+ }
+
+ // Check cond 2: all positions must be initialized.
+ if (initializedCount != vectorSize)
+ return failure();
+
+ SmallVector<Value> elements(vectorSize);
+ for (auto [insertBeginPosition, insertSize, valueToStore] :
+ llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
+ pendingInsertValues))) {
+ auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
+
+ if (!srcVectorType) {
+ elements[insertBeginPosition] = valueToStore;
+ continue;
+ }
+
+ SmallVector<Type> elementToInsertTypes(insertSize,
+ srcVectorType.getElementType());
+ // Get all elements from the vector in row-major order.
+ auto elementsToInsert = rewriter.create<vector::ToElementsOp>(
+ op.getLoc(), elementToInsertTypes, valueToStore);
+ for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
+ elements[insertBeginPosition + linearIdx] =
+ elementsToInsert.getResult(linearIdx);
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
+ return success();
+ }
+};
+
} // namespace
static Attribute
@@ -3301,13 +3491,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
!insertOp->hasOneUse())
return {};
- // Calculate the linearized position of the continuous chunk of elements to
- // insert.
- llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(insertOp.getStaticPosition(), completePositions.begin());
+ // Calculate the linearized position for inserting elements.
int64_t insertBeginPosition =
- linearize(completePositions, computeStrides(destTy.getShape()));
-
+ calculateInsertPosition(destTy, insertOp.getStaticPosition());
SmallVector<Attribute> insertedValues;
Type destEltType = destTy.getElementType();
@@ -3343,7 +3529,8 @@ static Value foldInsertUseChain(InsertOp insertOp) {
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+ InsertChainFullyInitialized>(context);
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
@@ -5599,7 +5786,7 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != baseType.getRank())
+ if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
@@ -5671,11 +5858,11 @@ public:
if (!isa<MemRefType>(op.getBase().getType()))
return rewriter.notifyMatchFailure(op, "base must be of memref type");
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
- op.getIndices(), op.getMask(),
+ op.getOffsets(), op.getMask(),
op.getPassThru());
return success();
}
@@ -5699,7 +5886,7 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
+ if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
@@ -5734,11 +5921,11 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
- op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2d5cc07..fe066dc 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
vector::populateVectorGatherLoweringPatterns(patterns);
}
+void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 6619619..546099c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
- gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
gatherOp.getPassThru());
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9e287fc..acbf2b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
+ LowerVectorFromElements.cpp
LowerVectorGather.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
new file mode 100644
index 0000000..c22fd54
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
@@ -0,0 +1,65 @@
+//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.from_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-from-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = ub.poison : vector<2x3xf32>
+/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange allElements = op.getElements();
+
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ size_t subTyNumElements = subTy.getNumElements();
+ assert((index + 1) * subTyNumElements <= allElements.size() &&
+ "out of bounds");
+ ValueRange subElements =
+ allElements.slice(index * subTyNumElements, subTyNumElements);
+ return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorFromElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55..9830189 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already 1-D");
-
- // Unrolling doesn't take vscale into account. Pattern is disabled for
- // vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
- return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
- Location loc = op.getLoc();
- Value indexVec = op.getIndexVec();
+ Value indexVec = op.getIndices();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
- Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
- rewriter.getZeroAttr(resultTy));
-
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
+ auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
Value indexSubVec =
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
@@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
- Value subGather = vector::GatherOp::create(
- rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
- maskSubVec, passThruSubVec);
- result =
- vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
- }
+ return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
+ op.getOffsets(), indexSubVec, maskSubVec,
+ passThruSubVec);
+ };
- rewriter.replaceOp(op, result);
- return success();
+ return unrollVectorOp(op, rewriter, unrollGatherFn);
}
};
@@ -158,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
- VectorType vType = op.getIndexVec().getType();
+ VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
Value newIdxs =
- arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+ arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = vector::GatherOp::create(
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
- op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+ op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);
return success();
@@ -212,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
- op.getIndexVec());
- auto baseOffsets = llvm::to_vector(op.getIndices());
+ op.getIndices());
+ auto baseOffsets = llvm::to_vector(op.getOffsets());
Value lastBaseOffset = baseOffsets.back();
Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f0..5617b06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ public:
// Replace the `vector.mask` operation.
rewriter.replaceOpWithNewOp<GatherOp>(
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
- gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
passthru);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bb0f339..c84eb2c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
if (!writeOp)
@@ -706,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern {
}
};
+/// Sink out step op feeding into a warp op yield.
+/// Vector step op is treated similar to arith.constant, apart from
+/// the result that represents a sequence [0, vec_size).
+/// Due to the to vec_size == warp_size limitation,
+/// we can simply wrap the lane id into a vector (i.e., broadcast).
+/// Supporting vec_size != warp_size may involve preserving the step
+/// result and using additional arith ops (the exact details are TBD).
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
+/// ...
+/// %cst = vector.step : vector<32xindex>
+/// gpu.yield %cst : vector<1xindex>
+/// }
+/// ```
+/// To
+/// ```
+/// gpu.warp_execute_on_lane_0(%arg0) {
+/// ...
+/// }
+/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
+struct WarpOpStep final : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
+ if (!yieldOperand)
+ return failure();
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
+ auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
+ VectorType resTy = stepOp.getResult().getType();
+ if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
+ resTy.getNumElements(), warpOp.getWarpSize()));
+ VectorType newVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ rewriter.setInsertionPointAfter(warpOp);
+ Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
+ newVecTy, warpOp.getLaneid());
+ rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
+ return success();
+ }
+};
+
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -846,8 +891,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
newYieldValues.reserve(warpOp->getNumResults());
DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
DenseMap<OpResult, int64_t> dedupResultPositionMap;
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +945,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Value valForwarded;
unsigned resultIndex;
for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1751,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto warpOpYield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
// Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = warpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
@@ -1826,7 +1868,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands);
+ forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
@@ -2019,7 +2062,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
- WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
+ WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 491b448..7dde631 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
}
};
+/// This pattern linearizes `vector.from_elements` operations by converting
+/// the result type to a 1-D vector while preserving all element values.
+/// The transformation creates a linearized `vector.from_elements` followed by
+/// a `vector.shape_cast` to restore the original multidimensional shape.
+///
+/// Example:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
+/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+///
+struct LinearizeVectorFromElements final
+ : public OpConversionPattern<vector::FromElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorFromElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+ LogicalResult
+ matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType dstTy =
+ getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
+ assert(dstTy && "vector type destination expected.");
+
+ OperandRange elements = fromElementsOp.getElements();
+ assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
+ "expected same number of elements");
+ rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
+ elements);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore>(typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c707f38..369857f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (isa<ViewLikeOpInterface>(user)) {
- users.append(user->getUsers().begin(), user->getUsers().end());
+ if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
@@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (isa<ViewLikeOpInterface>(user)) {
- users.append(user->getUsers().begin(), user->getUsers().end());
+ if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
+ Value viewDest = viewLike.getViewDest();
+ users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2269a40..dbb5eb3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -600,7 +600,7 @@ struct BubbleDownVectorBitCastForExtract
// Get the first element of the mixed position as integer.
auto mixedPos = extractOp.getMixedPosition();
- if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
+ if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
return failure();
uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
@@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
- auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+ auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
if (!resType)
return failure();
if (resType.getRank() != 2)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abec..e8ecb0c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// decomposed shape from each of the index, mask, and pass-through
// vectors.
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+ loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
auto slicedGather = vector::GatherOp::create(
- rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+ rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
indexSubVec, maskSubVec, passThruSubVec);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 6e2fa35..841e138 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
}
return success();
}
+
+LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ vector::UnrollVectorOpFn unrollFn) {
+ assert(op->getNumResults() == 1 && "expected single result");
+ assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
+ VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op->getLoc();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ Value subVector = unrollFn(rewriter, loc, subTy, i);
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 2ce32fe..89b62a2 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -22,6 +22,47 @@
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
+using namespace mlir;
+namespace {
+ParseResult parseElseRegion(OpAsmParser &opParser, Region &elseRegion) {
+ std::string keyword;
+ std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
+ if (keyword == "else")
+ return opParser.parseRegion(elseRegion);
+ return ParseResult::success();
+}
+
+void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
+ Region &elseRegion) {
+ if (elseRegion.empty())
+ return;
+ opPrinter.printKeywordOrString("else ");
+ opPrinter.printRegion(elseRegion);
+}
+
+ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
+ std::string keyword;
+ auto initLocation = opParser.getCurrentLocation();
+ std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
+ if (keyword == "nested" or keyword == "") {
+ visibility = StringAttr::get(opParser.getContext(), "nested");
+ return ParseResult::success();
+ }
+
+ if (keyword == "public" || keyword == "private") {
+ visibility = StringAttr::get(opParser.getContext(), keyword);
+ return ParseResult::success();
+ }
+ opParser.emitError(initLocation, "expecting symbol visibility");
+ return ParseResult::failure();
+}
+
+void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
+ Attribute visibility) {
+ opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
+}
+} // namespace
+
#define GET_OP_CLASSES
#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
@@ -29,7 +70,6 @@
#include "mlir/IR/Types.h"
#include "llvm/Support/LogicalResult.h"
-using namespace mlir;
using namespace wasmssa;
namespace {
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97c..7869a28 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,13 +7,18 @@ add_mlir_dialect_library(MLIRXeGPUDialect
DEPENDS
MLIRXeGPUIncGen
+ MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRIndexDialect
+ MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRIR
MLIRViewLikeInterface
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3c0ca114..7f3be7f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,12 +6,16 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
using std::optional;
@@ -33,10 +37,61 @@ void XeGPUDialect::initialize() {
>();
}
+/// Generates instructions to compute offsets for a subgroup identified by
+/// its multidimensional indices (sgId), using the specified subgroup layout
+/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
+/// dimensions (sizePerWg).
+static SmallVector<SmallVector<Value>>
+genOffsetsComputingInsts(OpBuilder &builder, Location loc,
+ SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
+ ArrayRef<int64_t> sizePerSg,
+ ArrayRef<int64_t> sizePerWg) {
+
+ SmallVector<SmallVector<Value>> offsets;
+
+ // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
+ SmallVector<Value> localOffsets = llvm::map_to_vector(
+ llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::MulOp>(
+ loc, std::get<0>(t),
+ builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+ });
+
+ // distUnit[i] is the minimum value between sizePerWg[i] and
+ // sgLayout[i] * sizePerSg[i]
+ SmallVector<int64_t> distUnit = llvm::map_to_vector(
+ llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+ for (SmallVector<int64_t> unitOffs :
+ StaticTileOffsetRange(sizePerWg, distUnit)) {
+ SmallVector<Value> base =
+ llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
+ return arith::ConstantIndexOp::create(builder, loc, d);
+ });
+
+ SmallVector<Value> adds = llvm::map_to_vector(
+ llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
+ std::get<1>(t));
+ });
+
+ SmallVector<Value> mods = llvm::map_to_vector(
+ llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::RemUOp>(
+ loc, std::get<0>(t),
+ arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
+ });
+
+ offsets.push_back(mods);
+ }
+ return offsets;
+}
+
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
- xegpu::LayoutAttr attr) {
+ xegpu::DistributeLayoutAttr attr) {
assert(attr && "Layout attribute is missing.");
// Checks whether the given shape can be evenly distributed using the
@@ -49,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
// smaller than `layout[i] * data[i]`, allowing multiple compute units to
// share the data.
auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
- DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
+ SmallVector<int64_t> layout,
+ SmallVector<int64_t> data,
bool rr = true) -> optional<SmallVector<int64_t>> {
llvm::SmallVector<int64_t> newShape(shape);
- if (layout) {
- auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
- if (vec.size() != shape.size())
+ if (layout.size()) {
+ if (layout.size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(shape, vec);
+ auto ratio = computeShapeRatio(shape, layout);
if (!ratio.has_value())
return std::nullopt;
newShape = ratio.value();
}
- if (data) {
- auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
- if (vec.size() != shape.size())
+ if (data.size()) {
+ if (data.size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(newShape, vec);
+ auto ratio = computeShapeRatio(newShape, data);
if (!ratio.has_value() && rr)
- ratio = computeShapeRatio(vec, newShape);
+ ratio = computeShapeRatio(data, newShape);
if (!ratio.has_value())
return std::nullopt;
// if data is not null, we always return it for next phase.
- newShape = vec;
+ newShape = data;
}
return newShape;
};
// check the sgLayout and sgData
auto maybeSgShape =
- tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
+ tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
if (!maybeSgShape)
return false;
auto sgShape = maybeSgShape.value();
// check InstData, it neither have layout nor need round-robin
auto maybeInstShape =
- tryDistribute(sgShape, nullptr, attr.getInstData(), false);
+ tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false);
if (!maybeInstShape)
return false;
auto instShape = maybeInstShape.value();
// check LaneLayout and LaneData
- auto maybeLaneShape =
- tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
+ auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
+ attr.getLaneDataAsInt(), false);
return maybeLaneShape.has_value();
}
@@ -211,6 +265,150 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return success();
}
+FailureOr<SmallVector<Value>>
+LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ // delinearizeSubgroupId is only available for
+ // workgroup-level layout attribute
+ if (!isForWorkgroup())
+ return failure();
+
+ // TODO: handle order attribute
+ auto hasDefaultOrder = [&]() {
+ DenseI32ArrayAttr order = getOrder();
+ return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
+ llvm::reverse(order.asArrayRef())));
+ };
+ if (!hasDefaultOrder())
+ return mlir::emitError(loc, "order attribute is currently not supported.");
+
+ auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
+
+ return affine::delinearizeIndex(builder, loc, linearId, dims);
+}
+
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ if (!isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
+ SmallVector<int64_t> sgShape = getSgDataAsInt();
+ if (sgShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+ }
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+ SmallVector<Value> sgIds = *maybeIds;
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
+ if (!parent || !dims)
+ return emitError() << "expected parent layout and dims attribute";
+
+ int64_t rank = parent.getRank();
+
+ // check every element in dims is unique and smaller than rank
+ llvm::SmallDenseSet<int64_t> seen;
+ for (int64_t dim : dims.asArrayRef()) {
+ if (dim < 0 || dim >= rank)
+ return emitError() << "invalid dim (" << dim << ") in slice attribute.";
+ if (!seen.insert(dim).second)
+ return emitError() << "repeated dim (" << dim << ") in slice attribute.";
+ }
+ return success();
+}
+
+SliceAttr SliceAttr::flatten() const {
+ xegpu::DistributeLayoutAttr parent = getParent();
+ SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
+
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
+ parent = sliceAttr.getParent();
+ slicedDims.push_back(sliceAttr.getDims());
+ }
+
+ auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
+ SmallVector<int64_t> indices =
+ llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
+
+ // get remaining dims (flattend) by applying slice ops with all slicedDims
+ SmallVector<int64_t> remainingDims(indices);
+ for (auto dim : llvm::reverse(slicedDims))
+ remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
+ dim.asArrayRef());
+
+ // get flattend sliced dims by applying slice ops with the remaining dims
+ SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
+ llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
+
+ return xegpu::SliceAttr::get(
+ getContext(), layoutAttr,
+ DenseI64ArrayAttr::get(getContext(), flattendDims));
+}
+
+FailureOr<SmallVector<Value>>
+SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.delinearizeSubgroupId(builder, loc, linearId);
+}
+
+/// Implements DistributeLayoutAttr::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+ if (!isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
+ SmallVector<int64_t> sgShape = getSgDataAsInt();
+ if (sgShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+ }
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+
+ // The effective sgIds for offsets computing correspond
+ // to the dims that are not sliced.
+ ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
+ SmallVector<Value> sgIds =
+ XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
@@ -230,7 +428,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
-mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+mlir::Type TensorDescType::parse(AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
@@ -280,7 +478,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
layout.value_or(mlir::Attribute()));
}
-void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+void TensorDescType::print(AsmPrinter &printer) const {
printer << "<";
auto shape = getShape();
@@ -325,10 +523,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, layout);
}
-LogicalResult TensorDescType::verify(
- llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
- llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
- mlir::Attribute encoding, mlir::Attribute layout) {
+LogicalResult
+TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
if (rank == 0)
@@ -394,6 +592,119 @@ LogicalResult TensorDescType::verify(
return success();
}
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
+mlir::Type MemDescType::parse(AsmParser &parser) {
+ llvm::SmallVector<int64_t> shape;
+ mlir::Type elementType;
+ mlir::FailureOr<MemLayoutAttr> layout;
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ auto shapeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
+ parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+ return {};
+ }
+
+ auto elemTypeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseType(elementType))) {
+ parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+ return {};
+ }
+
+ // parse optional attributes
+ if (mlir::succeeded(parser.parseOptionalComma())) {
+ MemLayoutAttr attr;
+ ParseResult res = parser.parseAttribute(attr);
+ if (mlir::failed(res))
+ return {};
+ layout = attr;
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ MLIRContext *ctxt = parser.getContext();
+ return MemDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+ elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+ printer << "<";
+
+ printer.printDimensionList(getShape());
+ printer << 'x';
+ printer << getElementType();
+
+ if (auto layout = getMemLayout())
+ printer << ", " << layout;
+
+ printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
+
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+
+ auto context = parser.getContext();
+ llvm::SMLoc loc = parser.getCurrentLocation();
+
+ llvm::SmallDenseSet<StringRef> seenKeys;
+ SmallVector<NamedAttribute> attributes;
+
+ auto parseElt = [&]() -> ParseResult {
+ StringRef nameId;
+ if (failed(parser.parseKeyword(&nameId)))
+ return parser.emitError(loc, "expected valid attribute name");
+
+ if (!seenKeys.insert(nameId).second)
+ return parser.emitError(loc, "duplicate key '")
+ << nameId << " in mem layout attribute";
+
+ if (failed(parser.parseEqual()))
+ return failure();
+
+ Attribute attr;
+ if (failed(parser.parseAttribute(attr)))
+ return failure();
+ attributes.emplace_back(nameId, attr);
+ return success();
+ };
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ if (failed(parser.parseCommaSeparatedList(parseElt)))
+ return {};
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ return parser.getChecked<MemLayoutAttr>(
+ loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
+ for (size_t i = 0; i < attrs.size(); i++) {
+ printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
+ if (i < attrs.size() - 1)
+ printer << ", ";
+ }
+ printer << ">";
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3..aca6654 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
namespace mlir {
namespace xegpu {
+bool isSharedMemory(const MemRefType &memrefTy) {
+ Attribute attr = memrefTy.getMemorySpace();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 3;
+ if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+ return memrefSpace.getValue() == MemorySpace::SLM;
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
@@ -45,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) {
return shape;
}
-static int64_t getRankOf(Value val) {
- auto type = val.getType();
- if (auto ty = llvm::dyn_cast<ShapedType>(type))
- return ty.getRank();
- return 0;
-}
-
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
if (!attr)
return true;
@@ -76,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
if (!tdescTy.isScattered())
return emitError() << "Expects a scattered TensorDesc.";
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (dyn_cast<VectorType>(maskTy))
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
- auto chunkSize = tdescTy.getChunkSizeAsInt();
if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
@@ -111,25 +122,49 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
}
static LogicalResult
-isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
- int64_t chunkSize,
+isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
+ VectorType valueTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto maskVecTy = dyn_cast<VectorType>(maskTy);
+ auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (maskVecTy || offsetsVecTy)
+ return emitError() << "Expecting scalar mask and offsets.";
+ else if (maskVecTy && offsetsVecTy)
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
+ auto valueSize = valueTy.getNumElements();
+ // SIMT mode with scalar mask and offsets.
+ if (!maskVecTy && !offsetsVecTy) {
+ if (valueSize != chunkSize)
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
- // a valid shape for SIMT case
- if (valueTy.getRank() == 1) {
- if (valueTy.getNumElements() != chunkSize)
- return emitError() << "value elements must match chunk size " << chunkSize
- << " for SIMT code.";
- return success();
+ if (!maskVecTy)
+ return emitError() << "Expecting a vector type mask.";
+ int64_t maskSize = maskVecTy.getNumElements();
+
+ if (chunkSize > 1) {
+ if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ } else {
+ if (valueSize != maskSize)
+ return emitError()
+ << "Mask should match value except the chunk size dim.";
}
-
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (maskSize == 1)
+ return success();
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
@@ -156,41 +191,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<MemRefType> source,
+ Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
+ Type srcTy = source.getType();
+ assert((isa<IntegerType, MemRefType>(srcTy)) &&
+ "Source has to be either int or memref.");
- llvm::SmallVector<int64_t> staticShape;
- llvm::SmallVector<int64_t> staticStrides;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
- dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
- auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
- build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
- dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
- staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<IntegerType> source,
- llvm::ArrayRef<OpFoldResult> shape,
- llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
-
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
- llvm::SmallVector<Value> dynamicShape;
- llvm::SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +210,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+ auto memrefShape = memrefTy.getShape();
+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+ // if shape and strides are from Memref, we don't need attributes for them
+ // to keep the IR print clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ staticShapeAttr = DenseI64ArrayAttr();
+ staticStridesAttr = DenseI64ArrayAttr();
+ }
+ }
+
build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
staticStridesAttr);
@@ -265,8 +289,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- auto rank = (int64_t)getMixedOffsets().size();
- bool invalidRank = false;
+ size_t rank = getMixedSizes().size();
+ bool invalidRank = rank != getMixedStrides().size();
bool invalidElemTy = false;
// Memory space of created TensorDesc should match with the source.
@@ -280,31 +304,28 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
+ if (size_t offsetRank = getMixedOffsets().size())
+ invalidRank |= (offsetRank != rank);
+
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
- auto memrefTy = dyn_cast<MemRefType>(getSourceType());
- if (memrefTy) {
- invalidRank |= (memrefTy.getRank() != rank);
+ if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
- }
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
- return emitOpError("Expecting strides and shape to be present for "
+ return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
- // mismatches among shape, strides, and offsets are
- // already handeled by OffsetSizeAndStrideOpInterface.
- // So they are not check here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");
// check result TensorDesc rank
- if (getType().getRank() > rank)
+ if (getType().getRank() > (int64_t)rank)
return emitOpError(
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");
@@ -360,13 +381,10 @@ ParseResult parseOptionalDynamicIndexList(
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
DenseI64ArrayAttr integers) {
-
- if (!integers)
+ if (!integers || integers.empty())
return;
-
- return printDynamicIndexList(printer, op, values, integers,
- /*scalableFlags=*/{}, {},
- AsmParser::Delimiter::Square);
+ printDynamicIndexList(printer, op, values, integers,
+ /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
@@ -381,6 +399,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+ l2_hint, l3_hint);
+}
+
LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -423,6 +456,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
l3_hint);
}
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ UnitAttr packed, DenseI64ArrayAttr transpose,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -529,6 +578,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
@@ -635,10 +699,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CreateDescOp::verify() {
auto tdescTy = getTensorDescType();
- if (getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
@@ -673,12 +733,14 @@ LogicalResult CreateDescOp::verify() {
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -689,6 +751,13 @@ LogicalResult PrefetchOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+ auto srcTy = getSourceType();
+ if (srcTy.isInteger() && !getOffsetAlignByteAttr())
+ return emitOpError("offset_align_byte is required with integer source.");
+
+ if (getOffsetAlignByteAttr() && !srcTy.isInteger())
+ return emitOpError("offset_align_byte only allowed with integer source.");
+
return success();
}
@@ -696,7 +765,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
- build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
+ IntegerAttr{});
}
//===----------------------------------------------------------------------===//
@@ -707,13 +777,15 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
+
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -730,10 +802,11 @@ LogicalResult LoadGatherOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -746,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -754,12 +843,14 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
- if (!tdescTy && getRankOf(getDest()) > 1)
- return emitOpError(
- "Expecting the dest is a 1D memref or pointer (uint64_t).");
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -778,10 +869,11 @@ LogicalResult StoreScatterOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -794,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
@@ -888,8 +998,8 @@ LogicalResult ConvertLayoutOp::verify() {
// both input and target layouts should be WgLayout or SgLayout at the same
// time.
- if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
- (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+ if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
+ (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
return emitOpError("expected input layout and target layout be WgLayout or "
"SgLayout at the same time.");
@@ -928,9 +1038,107 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<FoldConvertLayoutOp>(context);
}
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadMatrixOp
+//===----------------------------------------------------------------------===//
+void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ DistributeLayoutAttr layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult LoadMatrixOp::verify() {
+ VectorType resTy = getRes().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> valueShape = resTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed mem_desc shape.");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreMatrixOp
+//===----------------------------------------------------------------------===//
+void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ DistributeLayoutAttr layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult StoreMatrixOp::verify() {
+ VectorType dataTy = getData().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("data shape must not exceed mem_desc shape.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescSubviewOp
+//===----------------------------------------------------------------------===//
+
+void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
+ Type resTy, Value src,
+ llvm::ArrayRef<OpFoldResult> offsets) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
+}
+
+LogicalResult MemDescSubviewOp::verify() {
+ MemDescType srcTy = getSrc().getType();
+ MemDescType resTy = getRes().getType();
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> resShape = resTy.getShape();
+
+ if (srcTy.getRank() < resTy.getRank())
+ return emitOpError("result rank must not exceed source rank.");
+
+ if (llvm::any_of(
+ llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed source shape.");
+
+ if (srcTy.getStrides() != resTy.getStrides())
+ return emitOpError("result must inherit the source strides.");
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
+namespace mlir {
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
+} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index d82c541..9ee002e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -84,9 +84,10 @@ struct ConvertLayoutOpPattern
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
PatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
- xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
- if (!input_layout.getInstData() || !target_layout.getInstData())
+ xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
+ xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
+ if (input_layout.getInstDataAsInt().empty() ||
+ target_layout.getInstDataAsInt().empty())
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
input_layout = input_layout.dropInstData();
@@ -140,10 +141,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
else
value = (Value)operandOrResult;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
- if (layout && layout.isSgLayout()) {
- if (auto inst_data = layout.getInstData())
- return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(operandOrResult);
+ if (layout && layout.isForSubgroup()) {
+ if (!layout.getInstDataAsInt().empty())
+ return layout.getInstDataAsInt();
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -204,13 +206,15 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
// skip the op if any of its operands or results has workgroup level layouts
bool hasWgLayoutOperands =
llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
- return layout && layout.isWgLayout();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(opr);
+ return layout && layout.isForWorkgroup();
});
bool hasWgLayoutResults =
llvm::any_of(op->getOpResults(), [](OpResult result) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
- return layout && layout.isWgLayout();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(result);
+ return layout && layout.isForWorkgroup();
});
if (hasWgLayoutOperands || hasWgLayoutResults) {
LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
@@ -220,8 +224,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
Type valTy = value.getType();
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
- xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
- return layout && layout.getInstData();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
+ return layout && !layout.getInstDataAsInt().empty();
}
auto shapedType = dyn_cast<ShapedType>(valTy);
return shapedType && !llvm::equal(tileShape, shapedType.getShape());
@@ -247,7 +251,8 @@ void XeGPUBlockingPass::runOnOperation() {
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
// This ensures that the LayoutAttr remains accessible even if the defining
// operation is replaced.
- xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
+ xegpu::setDistributeLayoutAttrs(
+ op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
xegpu::LayoutAttr layout) {
@@ -272,7 +277,7 @@ void XeGPUBlockingPass::runOnOperation() {
auto layout =
llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
- if (layout && layout.isWgLayout())
+ if (layout && layout.isForWorkgroup())
return failure();
int count;
@@ -289,7 +294,7 @@ void XeGPUBlockingPass::runOnOperation() {
ArrayRef<int64_t> shape = type.getShape();
xegpu::LayoutAttr layout = type.getLayoutAttr();
- if (layout && layout.isWgLayout())
+ if (layout && layout.isForWorkgroup())
return failure();
int count;
@@ -377,7 +382,7 @@ void XeGPUBlockingPass::runOnOperation() {
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<LoopLikeOpInterface>(op))
- xegpu::setLayoutAttr(result, layout.dropInstData());
+ xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
}
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef8804..5cb47b2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
@@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder,
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
if (auto result = dyn_cast<OpResult>(successorInput))
- xegpu::setLayoutAttr(result, successorOperandLayout);
+ xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
}
}
return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c..dddb5ea 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -336,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -449,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
// Make sure the same load op is the last operation in the warp op body.
// This ensure that load op is not sinked earlier violating any barrier
// synchronizations.
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
return yield->getPrevNode() == op;
});
@@ -752,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -794,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
@@ -841,14 +837,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+ auto layout =
+ xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
signalPassFailure();
return;
}
- xegpu::setLayoutAttr(operand, layout);
+ xegpu::setDistributeLayoutAttr(operand, layout);
}
});
// Step 2: Move all operations of a GPU function inside
@@ -882,7 +879,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (vecRank == 0)
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+ // TODO: support more layout types
+ auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70c..9f627c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,38 +34,29 @@ using namespace mlir;
namespace {
-// Check if there is sg id range attached to the scf.if op.
-static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
- int64_t &endOfRange) {
- Operation *parent = op->getParentOp();
- // Find the outermost scf::IfOp with xegpu.sg_id_range.
+// Retrieve the RangeAttr if it is specified.
+static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
+ Operation *parent = op->getParentOfType<scf::IfOp>();
while (parent) {
- if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
- if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
- ifOp->getAttr("sg_id_range"))) {
- startOfRange = attr.getStart().getInt();
- endOfRange = attr.getEnd().getInt();
- break;
- }
- }
- parent = parent->getParentOp();
+ if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
+ parent->getAttr("sg_id_range")))
+ return attr;
+ parent = parent->getParentOfType<scf::IfOp>();
}
- // Return false if startOfRange is 0
- return (startOfRange > 0 && endOfRange > startOfRange);
+ return {};
}
static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+ xegpu::DistributeLayoutAttr layout) {
int count = 1;
SmallVector<int64_t> sgShape(shape);
-
- if (layout && layout.isWgLayout()) {
- DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
- auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
- sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
- else
- sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+ if (layout && layout.isForWorkgroup()) {
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+ if (!layout.getSgDataAsInt().empty())
+ sgShape = layout.getSgDataAsInt();
+ else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+ sgShape = *maybeDerivedSgData;
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
return std::make_pair(sgShape, count);
}
+/// Utility helper for deriving a list of offsets for each sub-TensorDescs
+/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the
+/// associated distribute layout attribute, the shape, subgroup id and the
+/// original offsets of the op
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
+ SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+ // not applicable to ops without offsets operands.
+ if (origOffsets.empty())
+ return failure();
+
+ // not applicable to ops without workgroup layout attributes
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+ // verify and adjust the sgId if the range specifier is present
+ xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
+ if (sgIdRange) {
+ int64_t startOfRange = sgIdRange.getStart().getInt();
+ int64_t endOfRange = sgIdRange.getEnd().getInt();
+ // verify the RangeAttr against the layout attribute
+ if (layout.getNumSubgroups() != endOfRange - startOfRange)
+ return rewriter.notifyMatchFailure(
+ op, "sg_layout size must match the sg_id_range");
+ // adjust the sgId if necessary
+ if (startOfRange > 0) {
+ Value startOfRangeVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+ }
+ }
+
+ // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
+ // descriptors to be accessed, based on the layout information.
+ ArrayRef<int64_t> wgShape = op.getDataShape();
+ auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(maybeDescOffsets))
+ return failure();
+
+ // Compute the final global offsets for each accessed sub-tensor
+ // or sub-memory descriptor.
+ for (const auto &sgOffsets : *maybeDescOffsets) {
+ SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+ offsetsList.push_back(std::move(newOffsets));
+ }
+
+ // callback(offsetsList);
+ return success();
+}
+
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
@@ -125,125 +177,74 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
- // Calculate offset for each subgroup
- static SmallVector<OpFoldResult>
- calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
- const SmallVector<OpFoldResult> &originalOffsets,
- const SmallVector<Value> &localOffset,
- const SmallVector<int64_t> &distUnitBaseAddr,
- const SmallVector<int64_t> &distUnitShape) {
- assert(localOffset.size() == distUnitBaseAddr.size() &&
- "localOffset and distUnitBaseAddr must have the same rank");
-
- SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
- originalOffsets.end());
- size_t rank = localOffset.size();
- for (size_t i = 0; i < rank; ++i) {
- size_t dimIdx = originalOffsets.size() - rank + i;
- Value constOffset =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]);
- Value offset =
- rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
- Value modValue =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]);
- Value offsetMod =
- rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
- Value origOffset = getValueOrCreateConstantIndexOp(
- rewriter, loc, originalOffsets[dimIdx]);
- Value globalOffset =
- rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
- globalOffsets[dimIdx] = globalOffset;
- }
-
- return globalOffsets;
- }
-
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
- auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
- if (!layout)
- return failure();
- Type elemTy = tdescTy.getElementType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
- // sgLayout must be present for workgroup-level distribution.
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
- return rewriter.notifyMatchFailure(
- op, "sgLayout attribute is required in layout");
-
+ Type elemTy = tdescTy.getElementType();
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ auto newTdescTy =
+ xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+ layout.dropSgLayoutAndData());
- // TODO : Handle order attribute
- // Get the subgroup ID
- auto linearSgId =
- gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- // Create constants for layout dimensions
- SmallVector<Value> sgLayoutDim(sgLayout.size());
- SmallVector<Value> sgDataDim(sgShape.size());
+ SmallVector<Value> newOps;
+ for (auto offsets : offsetsList) {
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
+ op.getMixedSizes(), op.getMixedStrides());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- sgLayoutDim[i] =
- arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
- sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+ newOps.push_back(newOp);
}
+ rewriter.replaceOpWithMultiple(op, {newOps});
- int64_t startOfRange = -1, endOfRange = -1;
- bool sgIdRangeSpecified =
- isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
- Value adjustedSgId = linearSgId;
- if (sgIdRangeSpecified) {
- int64_t sgCount = endOfRange - startOfRange;
- if (computeProduct(sgLayout) != sgCount)
- return rewriter.notifyMatchFailure(
- op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get the adjusted
- // sg id
- Value startOfRangeVal =
- arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- adjustedSgId =
- rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
- }
+ return success();
+ }
+};
+
+// This pattern transforms the CreateNdDescOp without offsets to create a
+// subgroup descriptor from a workgroup descriptor
+struct WgToSgCreateNdOpNoOffset
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
- auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
- if (failed(deLinearizeSgId))
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Check no offsets are specified.
+ if (!op.getMixedOffsets().empty())
return failure();
- SmallVector<Value> sgIds = *deLinearizeSgId;
-
- // Calculate distribution unit shape and local offsets for subgroup
- SmallVector<int64_t> distUnitShape(sgLayout.size());
- SmallVector<Value> localOffset(sgLayout.size());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
- localOffset[i] =
- rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
- }
- SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
xegpu::TensorDescType newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
layout.dropSgLayoutAndData());
- SmallVector<Value> newCreateNdOps;
- for (SmallVector<int64_t> distUnitBaseAddr :
- StaticTileOffsetRange(wgShape, distUnitShape)) {
- SmallVector<OpFoldResult> globalOffsets =
- calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
- distUnitBaseAddr, distUnitShape);
-
- auto newCreateNdOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
- op.getMixedSizes(), op.getMixedStrides());
- newCreateNdOps.push_back(newCreateNdOp);
- }
+
+ SmallVector<Value> newCreateNdOps(count);
+ std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+ return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+ op.getSource(), op.getMixedSizes(),
+ op.getMixedStrides());
+ });
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
return success();
@@ -256,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
LogicalResult
matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> newLoadOps;
-
- int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
+ if (!op.getMixedOffsets().empty())
return failure();
+ SmallVector<Value> newLoadOps;
for (auto src : adaptor.getTensorDesc()) {
xegpu::TensorDescType tdescTy =
dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -284,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
LogicalResult
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
+ if (!op.getMixedOffsets().empty())
return failure();
for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
@@ -298,6 +295,84 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ SmallVector<Value> newOps;
+ for (auto [tdesc, offsets] :
+ llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+ auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+ VectorType newResTy =
+ VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+ auto newOp = xegpu::LoadNdOp::create(
+ rewriter, op.getLoc(), newResTy, tdesc, offsets,
+ /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+ : public OpConversionPattern<xegpu::StoreNdOp> {
+ using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ for (auto [v, tdesc, offsets] :
+ llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
+ rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+ : public OpConversionPattern<xegpu::PrefetchNdOp> {
+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ for (auto [tdesc, offsets] :
+ llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+ rewriter.create<xegpu::PrefetchNdOp>(
+ op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
@@ -331,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
if (resultTy.getRank() != 2)
return failure();
- auto originalLayout = xegpu::getLayoutAttr(op.getResult());
+ auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
if (!originalLayout)
return failure();
@@ -354,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
resultTy.getElementType());
tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
- xegpu::setLayoutAttr(cast<OpResult>(tmpC),
- originalLayout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
+ originalLayout.dropSgLayoutAndData());
newDpasOps.push_back(tmpC);
}
@@ -395,8 +470,9 @@ struct WgToSgVectorBroadcastOp
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
return failure();
// TODO: Currently only supports cases where the source and result ranks
@@ -411,10 +487,8 @@ struct WgToSgVectorBroadcastOp
VectorType::get(sgShape, resultType.getElementType());
// Check if the output layout is distributable
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+ if (sgLayout.empty())
return failure();
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
@@ -433,8 +507,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -460,8 +534,9 @@ struct WgToSgElementwiseOp : public ConversionPattern {
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
+ if (!layout || !layout.isForWorkgroup())
return failure();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -526,8 +601,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
// is lowered to:
// #a = #xegpu.layout<inst_data = [16, 16]>
// #b = #xegpu.layout<inst_data = [8, 16]>
-// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
// clang-format on
struct WgToSgConvertLayoutOp
@@ -536,10 +611,12 @@ struct WgToSgConvertLayoutOp
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input = op.getInputLayout();
- xegpu::LayoutAttr target = op.getTargetLayout();
+ // TODO: currently, we only support LayoutAttr
+ auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
+ auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
- if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+ if (!input || !target || !input.isForWorkgroup() ||
+ !target.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
@@ -649,16 +726,213 @@ struct UnrealizedConversionCastOpPattern
}
};
+// This pattern distributes arith.constant op into subgroup-level constants
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ ArrayRef<int64_t> wgShape = vecType.getShape();
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+ // Current limitation: constant of vector with single value.
+ // TODO: support more complex cases, e.g., vector with multiple values.
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+
+ auto newType = VectorType::get(sgShape, vecType.getElementType());
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ SmallVector<Value> newConsts(count, cstOp);
+
+ rewriter.replaceOpWithMultiple(op, {newConsts});
+ return success();
+ }
+};
+
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data
+struct WgToSgLoadGatherOpWithOffset
+ : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return failure();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ // The offsets need to be distributed
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
+ SmallVector<Value> newLoadOps;
+ auto chunkSizeAttr =
+ rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+ VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+ for (auto [offsets, mask] :
+ llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newLoadOps.push_back(newLoadOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newLoadOps});
+ return success();
+ }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data
+struct WgToSgStoreScatterOpWithOffset
+ : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+ if (!valueType)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getValue());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ // The offsets need to be distributed
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
+ auto chunkSizeOpt = op.getChunkSize();
+ int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+ for (auto [val, offs, mask] : llvm::zip(
+ adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ // Update the layout attribute to drop sg_layout and sg_data.
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ op->setAttr("layout", newLayout);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+ using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ ArrayRef<int64_t> wgShape = op.getDataShape();
+ VectorType valueTy = op.getRes().getType();
+ Type elemTy = valueTy.getElementType();
+
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResTy = VectorType::get(sgShape, elemTy);
+ SmallVector<Value> newOps;
+ for (auto offsets : offsetsList) {
+ auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+ op.getLoc(), newResTy, op.getMemDesc(), offsets,
+ layout.dropSgLayoutAndData());
+ newOps.push_back(newOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+
+ return success();
+ }
+};
+
+struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
+ using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
+ if (failed(genOffsetsList(rewriter, op, offsetsList)))
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+ for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
+ rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
+ offsets,
+ layout.dropSgLayoutAndData());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
- patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ patterns
+ .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+ WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+ WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+ WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+ WgToSgStoreMatrixOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -748,8 +1022,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return xegpu::TensorDescType();
};
- auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
- return !layout || !layout.isWgLayout();
+ auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
+ return !layout || !layout.isForWorkgroup();
};
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
@@ -761,13 +1035,46 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
});
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
- auto layout = xegpu::getLayoutAttr(op.getResult());
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
+ [=](xegpu::LoadMatrixOp op) -> bool {
+ return isLegal(op.getLayoutAttr());
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
+ [=](xegpu::StoreMatrixOp op) -> bool {
+ return isLegal(op.getLayoutAttr());
+ });
+
+ target.addDynamicallyLegalOp<arith::ConstantOp>(
+ [=](arith::ConstantOp op) -> bool {
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecType)
+ return true;
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+ });
+
+ target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+ [=](xegpu::LoadGatherOp op) -> bool {
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+ [=](xegpu::StoreScatterOp op) -> bool {
+ // Check if the layout attribute is present on the result.
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+ if (!layout)
+ return true;
+ return isLegal(layout);
+ });
+
target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
@@ -795,7 +1102,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
}
}
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
return isLegal(layout);
});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index 98e84a4..d9bf4a1 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils
LINK_LIBS PUBLIC
MLIRIR
MLIRSCFTransforms
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRXeGPUDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2cf21fb..cac1ffe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,6 +11,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -38,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
// It only works for subgroup level layout, which only has lane_layout
// and lane_data, and is to distribute a SIMD code into SIMT code.
- if (!layout || !layout.isSgLayout())
+ if (!layout || !layout.isForSubgroup())
return failure();
SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
@@ -111,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) {
return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (!value)
return nullptr;
@@ -129,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
// for LoadNdOp, the layout is stored in the tensor descriptor
if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
- return getLayoutAttr(loadNd.getTensorDesc());
+ return getDistributeLayoutAttr(loadNd.getTensorDesc());
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
- return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+ return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -141,49 +144,51 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
- return getLayoutAttr(tiedInit->get());
+ return getDistributeLayoutAttr(tiedInit->get());
}
}
return nullptr;
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
- return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
- return getLayoutAttr(opr.get());
+ return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return getDistributeLayoutAttr(opr.get());
}
template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
+ const DistributeLayoutAttr layout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
+ if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->setAttr(name, layout);
}
// Explicit instantiation for OpResult
-template void
-xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
- const mlir::xegpu::LayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
+ const mlir::OpResult &result,
+ const mlir::xegpu::DistributeLayoutAttr layout);
// Explicit instantiation for OpOperand
-template void
-xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
- const mlir::xegpu::LayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
+ const mlir::OpOperand &operand,
+ const mlir::xegpu::DistributeLayoutAttr layout);
-void xegpu::setLayoutAttrs(Operation *op,
- function_ref<LayoutAttr(Value)> getLayoutImpl) {
+void xegpu::setDistributeLayoutAttrs(
+ Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getLayoutImpl(opr.get());
- setLayoutAttr(opr, layout);
+ setDistributeLayoutAttr(opr, layout);
}
for (OpResult result : nestOp->getOpResults()) {
auto layout = getLayoutImpl(result);
- setLayoutAttr(result, layout);
+ setDistributeLayoutAttr(result, layout);
}
});
}
@@ -192,7 +197,7 @@ template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (owner->hasAttrOfType<LayoutAttr>(name))
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->removeAttr(name);
}
@@ -303,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
if (!inputTy || !resultTy)
return WalkResult::skip();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(input);
if (!layout)
return WalkResult::skip();
@@ -341,7 +347,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
}
{ // perform the conversion from RankedTensorType to VectorType based on the
- // LayoutAttr
+ // DistributeLayoutAttr
// Handle the UnrealizedConversionCastOp introduced by the first step.
// For vector->RankedTensorType, it will simply forward the inputs.
@@ -404,3 +410,49 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
(void)mlir::applyPartialConversion(op, target, std::move(patterns));
}
}
+
+std::optional<std::string> xegpu::getChipStr(Operation *op) {
+ auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
+
+ if (!gpuModuleOp)
+ return std::nullopt;
+
+ auto targetAttrs = gpuModuleOp.getTargets();
+ if (targetAttrs) {
+ for (auto &attr : *targetAttrs) {
+ auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
+ if (xevmAttr)
+ return xevmAttr.getChip().str();
+ }
+ }
+
+ return std::nullopt;
+}
+
+/// Generates element-wise addition ops of two arrays with automatic alignment.
+/// When the input arrays have different sizes, the shorter array is
+/// right-aligned with the longer array, and the unmatched leading elements from
+/// the longer array are preserved unchanged. This is commonly used for offset
+/// computation where higher-dimensional offsets need to be added to
+/// lower-dimensional adjustments.
+///
+/// Example:
+/// lhs = [l1, l2, l3], rhs = [r1, r2]
+/// Result: [11, l2+r1, l3+r2]
+SmallVector<OpFoldResult>
+xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> lhs,
+ ArrayRef<OpFoldResult> rhs) {
+ // ensure a is longer than b
+ ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+ ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+ SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+ a = a.slice(a.size() - b.size());
+ for (auto [l, r] : llvm::zip(a, b)) {
+ auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+ auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+ results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+ }
+ return results;
+ return {};
+}
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index f704fbf..52162a4 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
}
// Compilation is lazy and it doesn't populate object cache unless requested.
// In case object dump is requested before cache is populated, we need to
- // force compilation manually.
+ // force compilation manually.
if (cache->isEmpty()) {
for (std::string &functionName : functionNames) {
auto result = lookupPacked(functionName);
@@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
return symbolMap;
};
engine->registerSymbols(runtimeSymbolMap);
-
- // Execute the global constructors from the module being processed.
- // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
- // crash for AArch64 see related issue #71963.
- if (!engine->jit->getTargetTriple().isAArch64())
- cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
-
return std::move(engine);
}
@@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
Error ExecutionEngine::invokePacked(StringRef name,
MutableArrayRef<void *> args) {
+ initialize();
auto expectedFPtr = lookupPacked(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
@@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,
return Error::success();
}
+
+void ExecutionEngine::initialize() {
+ if (isInitialized)
+ return;
+ // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
+ // crash for AArch64 see related issue #71963.
+ if (!jit->getTargetTriple().isAArch64())
+ cantFail(jit->initialize(jit->getMainJITDylib()));
+ isInitialized = true;
+}
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 2107df3..0ada4cc 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
auto engine = std::move(*expectedEngine);
+ engine->initialize();
+
auto expectedFPtr = engine->lookupPacked(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
diff --git a/mlir/lib/ExecutionEngine/VulkanRuntime.cpp b/mlir/lib/ExecutionEngine/VulkanRuntime.cpp
index 9f653b2..9452a56 100644
--- a/mlir/lib/ExecutionEngine/VulkanRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/VulkanRuntime.cpp
@@ -20,7 +20,7 @@
#include <iomanip>
#include <iostream>
-inline void emitVulkanError(const char *api, VkResult error) {
+static inline void emitVulkanError(const char *api, VkResult error) {
std::cerr << " failed with error code " << error << " when executing " << api;
}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 57825d9..27b47e2 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -251,6 +251,16 @@ bool Block::mightHaveTerminator() {
return !empty() && back().mightHaveTrait<OpTrait::IsTerminator>();
}
+iterator_range<Block::iterator> Block::without_terminator_impl() {
+ // Note: When the op is unregistered, we do not know for sure if the last
+ // op is a terminator. In that case, we include it in `without_terminator`,
+ // but that decision is somewhat arbitrary.
+ if (!back().hasTrait<OpTrait::IsTerminator>())
+ return {begin(), end()};
+ auto endIt = --end();
+ return {begin(), endIt};
+}
+
// Indexed successor access.
unsigned Block::getNumSuccessors() {
return empty() ? 0 : back().getNumSuccessors();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fd898b7..6f880f8 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Types.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include <optional>
@@ -1119,9 +1120,8 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
auto denseEltBitWidth = getDenseElementBitWidth(type);
auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT);
if (denseEltBitWidth != dataSize) {
- LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
- << denseEltBitWidth << " to match data size "
- << dataSize << " for type " << type << "\n");
+ LDBG() << "expected dense element bit width " << denseEltBitWidth
+ << " to match data size " << dataSize << " for type " << type;
return false;
}
@@ -1129,9 +1129,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
if (!isInt) {
bool valid = llvm::isa<FloatType>(type);
if (!valid)
- LLVM_DEBUG(llvm::dbgs()
- << "expected float type when isInt is false, but found "
- << type << "\n");
+ LDBG() << "expected float type when isInt is false, but found " << type;
return valid;
}
if (type.isIndex())
@@ -1139,9 +1137,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
auto intType = llvm::dyn_cast<IntegerType>(type);
if (!intType) {
- LLVM_DEBUG(llvm::dbgs()
- << "expected integer type when isInt is true, but found " << type
- << "\n");
+ LDBG() << "expected integer type when isInt is true, but found " << type;
return false;
}
@@ -1151,8 +1147,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
bool valid = intType.isSigned() == isSigned;
if (!valid)
- LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
- << " to match type " << type << "\n");
+ LDBG() << "expected signedness " << isSigned << " to match type " << type;
return valid;
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 3ef69ce..d95bdc9 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_library(MLIRIR
PatternMatch.cpp
Region.cpp
RegionKindInterface.cpp
+ Remarks.cpp
SymbolTable.cpp
TensorEncoding.cpp
Types.cpp
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index f84fe89..952619b 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -19,7 +19,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Regex.h"
#include <memory>
@@ -104,14 +104,8 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
auto it = registeredInterfaces.try_emplace(interface->getID(),
std::move(interface));
- (void)it;
- LLVM_DEBUG({
- if (!it.second) {
- llvm::dbgs() << "[" DEBUG_TYPE
- "] repeated interface registration for dialect "
- << getNamespace();
- }
- });
+ if (!it.second)
+ LDBG() << "repeated interface registration for dialect " << getNamespace();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 23e70c6..662681e 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -52,8 +52,8 @@ struct FileLineColRangeAttrStorage final
FileLineColRangeAttrStorage::totalSizeToAlloc<unsigned>(locEnc - 1);
auto *rawMem =
allocator.allocate(byteSize, alignof(FileLineColRangeAttrStorage));
- auto *result = ::new (rawMem) FileLineColRangeAttrStorage(
- std::move(std::get<0>(tblgenKey)), locEnc - 1);
+ auto *result = ::new (rawMem)
+ FileLineColRangeAttrStorage(std::get<0>(tblgenKey), locEnc - 1);
if (numInArray > 0) {
ArrayRef<unsigned> elements = std::get<1>(tblgenKey);
result->startLine = elements[0];
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2d5381d..1fa04ed 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -25,12 +25,13 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Remarks.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
@@ -134,6 +135,11 @@ public:
DiagnosticEngine diagEngine;
//===--------------------------------------------------------------------===//
+ // Remark
+ //===--------------------------------------------------------------------===//
+ std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
+
+ //===--------------------------------------------------------------------===//
// Options
//===--------------------------------------------------------------------===//
@@ -388,6 +394,19 @@ bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
//===----------------------------------------------------------------------===//
+// Remark Handlers
+//===----------------------------------------------------------------------===//
+
+void MLIRContext::setRemarkEngine(
+ std::unique_ptr<remark::detail::RemarkEngine> engine) {
+ getImpl().remarkEngine = std::move(engine);
+}
+
+remark::detail::RemarkEngine *MLIRContext::getRemarkEngine() {
+ return getImpl().remarkEngine.get();
+}
+
+//===----------------------------------------------------------------------===//
// Dialect and Operation Registration
//===----------------------------------------------------------------------===//
@@ -455,8 +474,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);
if (dialectIt.second) {
- LLVM_DEBUG(llvm::dbgs()
- << "Load new dialect in Context " << dialectNamespace << "\n");
+ LDBG() << "Load new dialect in Context " << dialectNamespace;
#ifndef NDEBUG
if (impl.multiThreadedExecutionContext != 0)
llvm::report_fatal_error(
@@ -525,8 +543,7 @@ DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
"' has already been registered");
}
- LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
- << dialectNamespace << "\n");
+ LDBG() << "Load new dynamic dialect in Context " << dialectNamespace;
#ifndef NDEBUG
if (impl.multiThreadedExecutionContext != 0)
llvm::report_fatal_error(
@@ -1192,11 +1209,10 @@ willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
maxSymbolPosition);
if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
- LLVM_DEBUG(
- llvm::dbgs()
+ LDBG()
<< "maximum dimensional identifier position in result expression must "
"be less than `dimCount` and maximum symbolic identifier position "
- "in result expression must be less than `symbolCount`\n");
+ "in result expression must be less than `symbolCount`";
return false;
}
return true;
diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp
index 69b4a56..f2665d2 100644
--- a/mlir/lib/IR/ODSSupport.cpp
+++ b/mlir/lib/IR/ODSSupport.cpp
@@ -112,7 +112,7 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, bool storage) {
}
template <typename DenseArrayTy, typename T>
-LogicalResult
+static LogicalResult
convertDenseArrayFromAttr(MutableArrayRef<T> storage, Attribute attr,
function_ref<InFlightDiagnostic()> emitError,
StringRef denseArrayTyStr) {
@@ -143,7 +143,7 @@ mlir::convertFromAttribute(MutableArrayRef<int32_t> storage, Attribute attr,
}
template <typename DenseArrayTy, typename T>
-LogicalResult
+static LogicalResult
convertDenseArrayFromAttr(SmallVectorImpl<T> &storage, Attribute attr,
function_ref<InFlightDiagnostic()> emitError,
StringRef denseArrayTyStr) {
diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp
new file mode 100644
index 0000000..78c9644
--- /dev/null
+++ b/mlir/lib/IR/Remarks.cpp
@@ -0,0 +1,279 @@
+//===- Remarks.cpp - MLIR Remarks -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Remarks.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+
+using namespace mlir::remark::detail;
+
+//------------------------------------------------------------------------------
+// Remark
+//------------------------------------------------------------------------------
+
+Remark::Arg::Arg(llvm::StringRef k, Value v) : key(k) {
+ llvm::raw_string_ostream os(val);
+ os << v;
+}
+
+Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) {
+ llvm::raw_string_ostream os(val);
+ os << t;
+}
+
+void Remark::insert(llvm::StringRef s) { args.emplace_back(s); }
+void Remark::insert(Arg a) { args.push_back(std::move(a)); }
+
+// Simple helper to print key=val list (sorted).
+static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
+ if (args.empty())
+ return;
+
+ llvm::SmallVector<Remark::Arg, 8> sorted(args.begin(), args.end());
+ llvm::sort(sorted, [](const Remark::Arg &a, const Remark::Arg &b) {
+ return a.key < b.key;
+ });
+
+ for (size_t i = 0; i < sorted.size(); ++i) {
+ const auto &a = sorted[i];
+ os << a.key << "=";
+
+ llvm::StringRef val(a.val);
+ bool needsQuote = val.contains(' ') || val.contains(',') ||
+ val.contains('{') || val.contains('}');
+ if (needsQuote)
+ os << '"' << val << '"';
+ else
+ os << val;
+
+ if (i + 1 < sorted.size())
+ os << ", ";
+ }
+}
+
+/// Print the remark to the given output stream.
+/// Example output:
+// clang-format off
+/// [Missed] Category: Loop | Pass:Unroller | Function=main | Reason="tripCount=4 < threshold=256"
+/// [Failure] LoopOptimizer | Reason="failed due to unsupported pattern"
+// clang-format on
+void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
+ // Header: [Type] pass:remarkName
+ StringRef type = getRemarkTypeString();
+ StringRef categoryName = getFullCategoryName();
+ StringRef name = remarkName;
+
+ os << '[' << type << "] ";
+ os << name << " | ";
+ if (!categoryName.empty())
+ os << "Category:" << categoryName << " | ";
+ if (!functionName.empty())
+ os << "Function=" << getFunction() << " | ";
+
+ if (printLocation) {
+ if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation()))
+ os << " @" << flc.getFilename() << ":" << flc.getLine() << ":"
+ << flc.getColumn();
+ }
+
+ printArgs(os, getArgs());
+}
+
+std::string Remark::getMsg() const {
+ std::string s;
+ llvm::raw_string_ostream os(s);
+ print(os);
+ os.flush();
+ return s;
+}
+
+llvm::StringRef Remark::getRemarkTypeString() const {
+ switch (remarkKind) {
+ case RemarkKind::RemarkUnknown:
+ return "Unknown";
+ case RemarkKind::RemarkPassed:
+ return "Passed";
+ case RemarkKind::RemarkMissed:
+ return "Missed";
+ case RemarkKind::RemarkFailure:
+ return "Failure";
+ case RemarkKind::RemarkAnalysis:
+ return "Analysis";
+ }
+ llvm_unreachable("Unknown remark kind");
+}
+
+llvm::remarks::Type Remark::getRemarkType() const {
+ switch (remarkKind) {
+ case RemarkKind::RemarkUnknown:
+ return llvm::remarks::Type::Unknown;
+ case RemarkKind::RemarkPassed:
+ return llvm::remarks::Type::Passed;
+ case RemarkKind::RemarkMissed:
+ return llvm::remarks::Type::Missed;
+ case RemarkKind::RemarkFailure:
+ return llvm::remarks::Type::Failure;
+ case RemarkKind::RemarkAnalysis:
+ return llvm::remarks::Type::Analysis;
+ }
+ llvm_unreachable("Unknown remark kind");
+}
+
+llvm::remarks::Remark Remark::generateRemark() const {
+ auto locLambda = [&]() -> llvm::remarks::RemarkLocation {
+ if (auto flc = dyn_cast<FileLineColLoc>(getLocation()))
+ return {flc.getFilename(), flc.getLine(), flc.getColumn()};
+ return {"<unknown file>", 0, 0};
+ };
+
+ llvm::remarks::Remark r; // The result.
+ r.RemarkType = getRemarkType();
+ r.RemarkName = getRemarkName();
+ // MLIR does not use passes; instead, it has categories and sub-categories.
+ r.PassName = getFullCategoryName();
+ r.FunctionName = getFunction();
+ r.Loc = locLambda();
+ for (const Remark::Arg &arg : getArgs()) {
+ r.Args.emplace_back();
+ r.Args.back().Key = arg.key;
+ r.Args.back().Val = arg.val;
+ }
+ return r;
+}
+
+//===----------------------------------------------------------------------===//
+// InFlightRemark
+//===----------------------------------------------------------------------===//
+
+InFlightRemark::~InFlightRemark() {
+ if (remark && owner)
+ owner->report(std::move(*remark));
+ owner = nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Remark Engine
+//===----------------------------------------------------------------------===//
+
+template <typename RemarkT, typename... Args>
+InFlightRemark RemarkEngine::makeRemark(Args &&...args) {
+ static_assert(std::is_base_of_v<Remark, RemarkT>,
+ "RemarkT must derive from Remark");
+ return InFlightRemark(*this,
+ std::make_unique<RemarkT>(std::forward<Args>(args)...));
+}
+
+template <typename RemarkT>
+InFlightRemark
+RemarkEngine::emitIfEnabled(Location loc, RemarkOpts opts,
+ bool (RemarkEngine::*isEnabled)(StringRef) const) {
+ return (this->*isEnabled)(opts.categoryName) ? makeRemark<RemarkT>(loc, opts)
+ : InFlightRemark{};
+}
+
+bool RemarkEngine::isMissedOptRemarkEnabled(StringRef categoryName) const {
+ return missFilter && missFilter->match(categoryName);
+}
+
+bool RemarkEngine::isPassedOptRemarkEnabled(StringRef categoryName) const {
+ return passedFilter && passedFilter->match(categoryName);
+}
+
+bool RemarkEngine::isAnalysisOptRemarkEnabled(StringRef categoryName) const {
+ return analysisFilter && analysisFilter->match(categoryName);
+}
+
+bool RemarkEngine::isFailedOptRemarkEnabled(StringRef categoryName) const {
+ return failedFilter && failedFilter->match(categoryName);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemark(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkPass>(loc, opts,
+ &RemarkEngine::isPassedOptRemarkEnabled);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemarkMiss(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkMissed>(
+ loc, opts, &RemarkEngine::isMissedOptRemarkEnabled);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemarkFailure(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkFailure>(
+ loc, opts, &RemarkEngine::isFailedOptRemarkEnabled);
+}
+
+InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
+ RemarkOpts opts) {
+ return emitIfEnabled<OptRemarkAnalysis>(
+ loc, opts, &RemarkEngine::isAnalysisOptRemarkEnabled);
+}
+
+//===----------------------------------------------------------------------===//
+// RemarkEngine
+//===----------------------------------------------------------------------===//
+
+void RemarkEngine::report(const Remark &&remark) {
+ // Stream the remark
+ if (remarkStreamer)
+ remarkStreamer->streamOptimizationRemark(remark);
+
+ // Print using MLIR's diagnostic
+ if (printAsEmitRemarks)
+ emitRemark(remark.getLocation(), remark.getMsg());
+}
+
+RemarkEngine::~RemarkEngine() {
+ if (remarkStreamer)
+ remarkStreamer->finalize();
+}
+
+llvm::LogicalResult
+RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::string *errMsg) {
+ // If you need to validate categories/filters, do so here and set errMsg.
+ remarkStreamer = std::move(streamer);
+ return success();
+}
+
+RemarkEngine::RemarkEngine(bool printAsEmitRemarks,
+ const RemarkCategories &cats)
+ : printAsEmitRemarks(printAsEmitRemarks) {
+ if (cats.passed)
+ passedFilter = llvm::Regex(cats.passed.value());
+ if (cats.missed)
+ missFilter = llvm::Regex(cats.missed.value());
+ if (cats.analysis)
+ analysisFilter = llvm::Regex(cats.analysis.value());
+ if (cats.failed)
+ failedFilter = llvm::Regex(cats.failed.value());
+}
+
+llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
+ MLIRContext &ctx,
+ std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
+ const remark::RemarkCategories &cats, bool printAsEmitRemarks) {
+ auto engine =
+ std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats);
+
+ std::string errMsg;
+ if (failed(engine->initialize(std::move(streamer), &errMsg))) {
+ llvm::report_fatal_error(
+ llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg);
+ }
+ ctx.setRemarkEngine(std::move(engine));
+
+ return success();
+}
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
index 266f6db..b5a6888 100644
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -304,6 +304,11 @@ template bool
mlir::hasEffect<BlockArgument, MemoryEffects::Write, MemoryEffects::Free>(
Operation *, BlockArgument);
+bool mlir::hasUnknownEffects(Operation *op) {
+ return !isa<MemoryEffectOpInterface>(op) &&
+ !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+}
+
bool mlir::wouldOpBeTriviallyDead(Operation *op) {
if (op->mightHaveTrait<OpTrait::IsTerminator>())
return false;
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939..af4ea5a 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -290,8 +290,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
DivisionFixupFn fixup) {
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
&rhsMax = rhs.umax();
-
- if (!rhsMin.isZero()) {
+ if (!rhsMin.isZero() && !rhsMax.isZero()) {
auto udiv = [&fixup](const APInt &a,
const APInt &b) -> std::optional<APInt> {
return fixup(a, b, a.udiv(b));
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index c9481fb..caa9091 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -813,7 +813,7 @@ FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
return false;
// Keep processing as long as the strong relation cannot be proven.
FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
- return failed(ordered) ? true : false;
+ return failed(ordered);
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 98ed4cc..ec18c48 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -35,7 +35,7 @@ public:
std::optional<DynMatcher> getDynMatcher() const override {
std::vector<DynMatcher> dynMatchers;
- for (auto variantMatcher : args) {
+ for (const auto &variantMatcher : args) {
std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
if (dynMatcher)
dynMatchers.push_back(dynMatcher.value());
@@ -66,8 +66,7 @@ VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
VariantMatcher
VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
ArrayRef<VariantMatcher> args) {
- return VariantMatcher(
- std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
+ return VariantMatcher(std::make_shared<VariadicOpPayload>(varOp, args));
}
std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 03e4177..375e820 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -141,7 +141,7 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
os << "\n";
for (auto &results : matches) {
os << "Match #" << ++matchCount << ":\n\n";
- for (auto op : results.matchedOps) {
+ for (Operation *op : results.matchedOps) {
if (op == results.rootOp) {
finder.printMatch(os, qs, op, "root");
} else {
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 950b85e2..258fed1 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -102,6 +102,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "mlir/Target/SPIRV/Target.h"
/// Add all the MLIR dialects to the provided registry.
@@ -199,6 +200,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+ xevm::registerXeVMTargetInterfaceExternalModels(registry);
}
/// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 8f7c67c..69a85db 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -28,6 +28,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -58,6 +59,7 @@
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
/// This function may be called to register all MLIR dialect extensions with the
/// provided registry.
@@ -80,6 +82,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
registerConvertMemRefToEmitCInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
+ ptr::registerConvertPtrToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index 1ed3a37..c67b242 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -45,6 +45,7 @@
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
// This function may be called to register the MLIR passes with the
@@ -74,6 +75,7 @@ void mlir::registerAllPasses() {
registerNVGPUPasses();
registerSparseTensorPasses();
LLVM::registerLLVMPasses();
+ LLVM::registerTargetLLVMIRTransformsPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
shard::registerShardPasses();
diff --git a/mlir/lib/Remark/CMakeLists.txt b/mlir/lib/Remark/CMakeLists.txt
new file mode 100644
index 0000000..920a95d
--- /dev/null
+++ b/mlir/lib/Remark/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_library(MLIRRemarkStreamer
+ RemarkStreamer.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Remark
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+
+ LINK_COMPONENTS
+ Remarks
+ Core
+ BitstreamReader
+ )
diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp
new file mode 100644
index 0000000..8e3544f
--- /dev/null
+++ b/mlir/lib/Remark/RemarkStreamer.cpp
@@ -0,0 +1,69 @@
+#include "mlir/Remark/RemarkStreamer.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Remarks.h"
+
+#include "llvm/Remarks/RemarkSerializer.h"
+#include "llvm/Remarks/RemarkStreamer.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+namespace mlir::remark::detail {
+
+FailureOr<std::unique_ptr<MLIRRemarkStreamerBase>>
+LLVMRemarkStreamer::createToFile(llvm::StringRef path,
+ llvm::remarks::Format fmt) {
+ std::error_code ec;
+ // Use error_code ctor; YAML is text. (Bitstream also works fine here.)
+ auto f =
+ std::make_unique<llvm::ToolOutputFile>(path, ec, llvm::sys::fs::OF_Text);
+ if (ec)
+ return failure();
+
+ auto serOr = llvm::remarks::createRemarkSerializer(
+ fmt, llvm::remarks::SerializerMode::Separate, f->os());
+ if (!serOr) {
+ llvm::consumeError(serOr.takeError());
+ return failure();
+ }
+
+ auto rs =
+ std::make_unique<llvm::remarks::RemarkStreamer>(std::move(*serOr), path);
+
+ auto impl = std::unique_ptr<LLVMRemarkStreamer>(new LLVMRemarkStreamer());
+ impl->remarkStreamer = std::move(rs);
+ impl->file = std::move(f);
+ return std::unique_ptr<MLIRRemarkStreamerBase>(std::move(impl));
+}
+
+void LLVMRemarkStreamer::streamOptimizationRemark(const Remark &remark) {
+ if (!remarkStreamer->matchesFilter(remark.getCategoryName()))
+ return;
+
+ // First, convert the diagnostic to a remark.
+ llvm::remarks::Remark r = remark.generateRemark();
+ // Then, emit the remark through the serializer.
+ remarkStreamer->getSerializer().emit(r);
+}
+
+LLVMRemarkStreamer::~LLVMRemarkStreamer() {
+ if (file && remarkStreamer)
+ file->keep();
+}
+} // namespace mlir::remark::detail
+
+namespace mlir::remark {
+LogicalResult enableOptimizationRemarksWithLLVMStreamer(
+ MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt,
+ const RemarkCategories &cat, bool printAsEmitRemarks) {
+
+ FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr =
+ detail::LLVMRemarkStreamer::createToFile(path, fmt);
+ if (failed(sOr))
+ return failure();
+
+ return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat,
+ printAsEmitRemarks);
+}
+
+} // namespace mlir::remark
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ae3f22d..5cbea5d 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -20,8 +20,10 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InterleavedRange.h"
#include <numeric>
#include <optional>
@@ -707,10 +709,8 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
}
// Print the index usage and ensure that we did not run out of index space.
- LLVM_DEBUG({
- llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
- << "(down from initial " << valueDefRanges.size() << ").\n";
- });
+ LDBG() << "Allocated " << allocatedIndices.size() << " indices "
+ << "(down from initial " << valueDefRanges.size() << ").";
assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
"Ran out of memory for allocated indices");
@@ -736,6 +736,7 @@ void Generator::generate(Region *region, ByteCodeWriter &writer) {
}
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
+ LDBG() << "Generating bytecode for operation: " << op->getName();
LLVM_DEBUG({
// The following list must contain all the operations that do not
// produce any bytecode.
@@ -1275,12 +1276,8 @@ private:
/// Handle a switch operation with the provided value and cases.
template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
- LLVM_DEBUG({
- llvm::dbgs() << " * Value: " << value << "\n"
- << " * Cases: ";
- llvm::interleaveComma(cases, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "Switch operation:\n * Value: " << value
+ << "\n * Cases: " << llvm::interleaved(cases);
// Check to see if the attribute value is within the case list. Jump to
// the correct successor index based on the result.
@@ -1424,38 +1421,27 @@ private:
} // namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
+ LDBG() << "Executing ApplyConstraint:";
ByteCodeField fun_idx = read();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Arguments: " << llvm::interleaved(args);
ByteCodeField isNegated = read();
- LLVM_DEBUG({
- llvm::dbgs() << " * isNegated: " << isNegated << "\n";
- llvm::interleaveComma(args, llvm::dbgs());
- });
+ LDBG() << " * isNegated: " << isNegated;
ByteCodeField numResults = read();
const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
[[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
- LLVM_DEBUG({
- if (succeeded(rewriteResult)) {
- llvm::dbgs() << " * Constraint succeeded\n";
- llvm::dbgs() << " * Results: ";
- llvm::interleaveComma(constraintResults, llvm::dbgs());
- llvm::dbgs() << "\n";
- } else {
- llvm::dbgs() << " * Constraint failed\n";
- }
- });
+ if (succeeded(rewriteResult)) {
+ LDBG() << " * Constraint succeeded, results: "
+ << llvm::interleaved(constraintResults);
+ } else {
+ LDBG() << " * Constraint failed";
+ }
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
"native PDL rewrite function succeeded but returned "
"unexpected number of results");
@@ -1466,15 +1452,12 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
}
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
+ LDBG() << "Executing ApplyRewrite:";
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- });
+ LDBG() << " * Arguments: " << llvm::interleaved(args);
// Execute the rewrite function.
ByteCodeField numResults = read();
@@ -1487,7 +1470,7 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
processNativeFunResults(results, numResults, rewriteResult);
if (failed(rewriteResult)) {
- LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ LDBG() << " - Failed";
return failure();
}
return success();
@@ -1516,7 +1499,7 @@ void ByteCodeExecutor::processNativeFunResults(
PDLValue::Kind resultKind = read<PDLValue::Kind>();
(void)resultKind;
PDLValue result = results.getResults()[resultIdx];
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
+ LDBG() << " * Result: " << result;
assert(result.getKind() == resultKind &&
"native PDL rewrite function returned an unexpected type of "
"result");
@@ -1544,16 +1527,16 @@ void ByteCodeExecutor::processNativeFunResults(
}
void ByteCodeExecutor::executeAreEqual() {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+ LDBG() << "Executing AreEqual:";
const void *lhs = read<const void *>();
const void *rhs = read<const void *>();
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(lhs == rhs);
}
void ByteCodeExecutor::executeAreRangesEqual() {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
+ LDBG() << "Executing AreRangesEqual:";
PDLValue::Kind valueKind = read<PDLValue::Kind>();
const void *lhs = read<const void *>();
const void *rhs = read<const void *>();
@@ -1562,14 +1545,14 @@ void ByteCodeExecutor::executeAreRangesEqual() {
case PDLValue::Kind::TypeRange: {
const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(*lhsRange == *rhsRange);
break;
}
case PDLValue::Kind::ValueRange: {
const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(*lhsRange == *rhsRange);
break;
}
@@ -1579,20 +1562,19 @@ void ByteCodeExecutor::executeAreRangesEqual() {
}
void ByteCodeExecutor::executeBranch() {
- LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
+ LDBG() << "Executing Branch";
curCodeIt = &code[read<ByteCodeAddr>()];
}
void ByteCodeExecutor::executeCheckOperandCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
+ LDBG() << "Executing CheckOperandCount:";
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
bool compareAtLeast = read();
- LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
- << " * Expected: " << expectedCount << "\n"
- << " * Comparator: "
- << (compareAtLeast ? ">=" : "==") << "\n");
+ LDBG() << " * Found: " << op->getNumOperands()
+ << "\n * Expected: " << expectedCount
+ << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
if (compareAtLeast)
selectJump(op->getNumOperands() >= expectedCount);
else
@@ -1600,25 +1582,24 @@ void ByteCodeExecutor::executeCheckOperandCount() {
}
void ByteCodeExecutor::executeCheckOperationName() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
+ LDBG() << "Executing CheckOperationName:";
Operation *op = read<Operation *>();
OperationName expectedName = read<OperationName>();
- LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
- << " * Expected: \"" << expectedName << "\"\n");
+ LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
+ << expectedName << "\"";
selectJump(op->getName() == expectedName);
}
void ByteCodeExecutor::executeCheckResultCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
+ LDBG() << "Executing CheckResultCount:";
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
bool compareAtLeast = read();
- LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
- << " * Expected: " << expectedCount << "\n"
- << " * Comparator: "
- << (compareAtLeast ? ">=" : "==") << "\n");
+ LDBG() << " * Found: " << op->getNumResults()
+ << "\n * Expected: " << expectedCount
+ << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
if (compareAtLeast)
selectJump(op->getNumResults() >= expectedCount);
else
@@ -1626,36 +1607,35 @@ void ByteCodeExecutor::executeCheckResultCount() {
}
void ByteCodeExecutor::executeCheckTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+ LDBG() << "Executing AreEqual:";
TypeRange *lhs = read<TypeRange *>();
Attribute rhs = read<Attribute>();
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ LDBG() << " * " << lhs << " == " << rhs;
selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
}
void ByteCodeExecutor::executeContinue() {
ByteCodeField level = read();
- LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
- << " * Level: " << level << "\n");
+ LDBG() << "Executing Continue\n * Level: " << level;
++loopIndex[level];
popCodeIt();
}
void ByteCodeExecutor::executeCreateConstantTypeRange() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
+ LDBG() << "Executing CreateConstantTypeRange:";
unsigned memIndex = read();
unsigned rangeIndex = read();
ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
- LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
+ LDBG() << " * Types: " << typesAttr;
assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
rangeIndex);
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc) {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
+ LDBG() << "Executing CreateOperation:";
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
@@ -1696,45 +1676,37 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
Operation *resultOp = rewriter.create(state);
memory[memIndex] = resultOp;
- LLVM_DEBUG({
- llvm::dbgs() << " * Attributes: "
- << state.attributes.getDictionary(state.getContext())
- << "\n * Operands: ";
- llvm::interleaveComma(state.operands, llvm::dbgs());
- llvm::dbgs() << "\n * Result Types: ";
- llvm::interleaveComma(state.types, llvm::dbgs());
- llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
- });
+ LDBG() << " * Attributes: "
+ << state.attributes.getDictionary(state.getContext())
+ << "\n * Operands: " << llvm::interleaved(state.operands)
+ << "\n * Result Types: " << llvm::interleaved(state.types)
+ << "\n * Result: " << *resultOp;
}
template <typename T>
void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
+ LDBG() << "Executing CreateDynamic" << type << "Range:";
unsigned memIndex = read();
unsigned rangeIndex = read();
SmallVector<T> values;
readList(values);
- LLVM_DEBUG({
- llvm::dbgs() << "\n * " << type << "s: ";
- llvm::interleaveComma(values, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * " << type << "s: " << llvm::interleaved(values);
assignRangeToMemory(values, memIndex, rangeIndex);
}
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
+ LDBG() << "Executing EraseOp:";
Operation *op = read<Operation *>();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+ LDBG() << " * Operation: " << *op;
rewriter.eraseOp(op);
}
template <typename T, typename Range, PDLValue::Kind kind>
void ByteCodeExecutor::executeExtract() {
- LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
+ LDBG() << "Executing Extract" << kind << ":";
Range *range = read<Range *>();
unsigned index = read<uint32_t>();
unsigned memIndex = read();
@@ -1745,18 +1717,16 @@ void ByteCodeExecutor::executeExtract() {
}
T result = index < range->size() ? (*range)[index] : T();
- LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
- << " * Index: " << index << "\n"
- << " * Result: " << result << "\n");
+ LDBG() << " * " << kind << "s(" << range->size() << ")";
+ LDBG() << " * Index: " << index;
+ LDBG() << " * Result: " << result;
storeToMemory(memIndex, result);
}
-void ByteCodeExecutor::executeFinalize() {
- LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
-}
+void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
void ByteCodeExecutor::executeForEach() {
- LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
+ LDBG() << "Executing ForEach:";
const ByteCodeField *prevCodeIt = getPrevCodeIt();
unsigned rangeIndex = read();
unsigned memIndex = read();
@@ -1768,12 +1738,12 @@ void ByteCodeExecutor::executeForEach() {
ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
assert(index <= array.size() && "iterated past the end");
if (index < array.size()) {
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
+ LDBG() << " * Result: " << array[index];
value = array[index];
break;
}
- LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ LDBG() << " * Done";
index = 0;
selectJump(size_t(0));
return;
@@ -1791,49 +1761,47 @@ void ByteCodeExecutor::executeForEach() {
}
void ByteCodeExecutor::executeGetAttribute() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
+ LDBG() << "Executing GetAttribute:";
unsigned memIndex = read();
Operation *op = read<Operation *>();
StringAttr attrName = read<StringAttr>();
Attribute attr = op->getAttr(attrName);
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Attribute: " << attrName << "\n"
- << " * Result: " << attr << "\n");
+ LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
+ << "\n * Result: " << attr;
memory[memIndex] = attr.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetAttributeType() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
+ LDBG() << "Executing GetAttributeType:";
unsigned memIndex = read();
Attribute attr = read<Attribute>();
Type type;
if (auto typedAttr = dyn_cast<TypedAttr>(attr))
type = typedAttr.getType();
- LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
- << " * Result: " << type << "\n");
+ LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
memory[memIndex] = type.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetDefiningOp() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
+ LDBG() << "Executing GetDefiningOp:";
unsigned memIndex = read();
Operation *op = nullptr;
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
Value value = read<Value>();
if (value)
op = value.getDefiningOp();
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
} else {
ValueRange *values = read<ValueRange *>();
if (values && !values->empty()) {
op = values->front().getDefiningOp();
}
- LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
+ LDBG() << " * Values: " << values;
}
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
+ LDBG() << " * Result: " << op;
memory[memIndex] = op;
}
@@ -1843,9 +1811,8 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) {
Value operand =
index < op->getNumOperands() ? op->getOperand(index) : Value();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Index: " << index << "\n"
- << " * Result: " << operand << "\n");
+ LDBG() << " * Operation: " << *op << "\n * Index: " << index
+ << "\n * Result: " << operand;
memory[memIndex] = operand.getAsOpaquePointer();
}
@@ -1860,13 +1827,12 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
// Check for the sentinel index that signals that all values should be
// returned.
if (index == std::numeric_limits<uint32_t>::max()) {
- LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
+ LDBG() << " * Getting all values";
// `values` is already the full value range.
// Otherwise, check to see if this operation uses AttrSizedSegments.
} else if (op->hasTrait<AttrSizedSegmentsT>()) {
- LLVM_DEBUG(llvm::dbgs()
- << " * Extracting values from `" << attrSizedSegments << "`\n");
+ LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
@@ -1877,16 +1843,15 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
std::accumulate(segments.begin(), segments.begin() + index, 0);
values = values.slice(startIndex, *std::next(segments.begin(), index));
- LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
- << *std::next(segments.begin(), index) << "]\n");
+ LDBG() << " * Extracting range[" << startIndex << ", "
+ << *std::next(segments.begin(), index) << "]";
// Otherwise, assume this is the last operand group of the operation.
// FIXME: We currently don't support operations with
// SameVariadicOperandSize/SameVariadicResultSize here given that we don't
// have a way to detect it's presence.
} else if (values.size() >= index) {
- LLVM_DEBUG(llvm::dbgs()
- << " * Treating values as trailing variadic range\n");
+ LDBG() << " * Treating values as trailing variadic range";
values = values.drop_front(index);
// If we couldn't detect a way to compute the values, bail out.
@@ -1905,7 +1870,7 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
}
void ByteCodeExecutor::executeGetOperands() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
+ LDBG() << "Executing GetOperands:";
unsigned index = read<uint32_t>();
Operation *op = read<Operation *>();
ByteCodeField rangeIndex = read();
@@ -1914,7 +1879,7 @@ void ByteCodeExecutor::executeGetOperands() {
op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
valueRangeMemory);
if (!result)
- LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
+ LDBG() << " * Invalid operand range";
memory[read()] = result;
}
@@ -1924,14 +1889,13 @@ void ByteCodeExecutor::executeGetResult(unsigned index) {
OpResult result =
index < op->getNumResults() ? op->getResult(index) : OpResult();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Index: " << index << "\n"
- << " * Result: " << result << "\n");
+ LDBG() << " * Operation: " << *op << "\n * Index: " << index
+ << "\n * Result: " << result;
memory[memIndex] = result.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetResults() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
+ LDBG() << "Executing GetResults:";
unsigned index = read<uint32_t>();
Operation *op = read<Operation *>();
ByteCodeField rangeIndex = read();
@@ -1940,12 +1904,12 @@ void ByteCodeExecutor::executeGetResults() {
op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
valueRangeMemory);
if (!result)
- LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
+ LDBG() << " * Invalid result range";
memory[read()] = result;
}
void ByteCodeExecutor::executeGetUsers() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
+ LDBG() << "Executing GetUsers:";
unsigned memIndex = read();
unsigned rangeIndex = read();
OwningOpRange &range = opRangeMemory[rangeIndex];
@@ -1957,7 +1921,7 @@ void ByteCodeExecutor::executeGetUsers() {
Value value = read<Value>();
if (!value)
return;
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
// Extract the users of a single value.
range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
@@ -1967,11 +1931,8 @@ void ByteCodeExecutor::executeGetUsers() {
ValueRange *values = read<ValueRange *>();
if (!values)
return;
- LLVM_DEBUG({
- llvm::dbgs() << " * Values (" << values->size() << "): ";
- llvm::interleaveComma(*values, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Values (" << values->size()
+ << "): " << llvm::interleaved(*values);
// Extract all the users of a range of values.
SmallVector<Operation *> users;
@@ -1981,54 +1942,49 @@ void ByteCodeExecutor::executeGetUsers() {
llvm::copy(users, range.begin());
}
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
+ LDBG() << " * Result: " << range.size() << " operations";
}
void ByteCodeExecutor::executeGetValueType() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
+ LDBG() << "Executing GetValueType:";
unsigned memIndex = read();
Value value = read<Value>();
Type type = value ? value.getType() : Type();
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
- << " * Result: " << type << "\n");
+ LDBG() << " * Value: " << value << "\n * Result: " << type;
memory[memIndex] = type.getAsOpaquePointer();
}
void ByteCodeExecutor::executeGetValueRangeTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
+ LDBG() << "Executing GetValueRangeTypes:";
unsigned memIndex = read();
unsigned rangeIndex = read();
ValueRange *values = read<ValueRange *>();
if (!values) {
- LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
+ LDBG() << " * Values: <NULL>";
memory[memIndex] = nullptr;
return;
}
- LLVM_DEBUG({
- llvm::dbgs() << " * Values (" << values->size() << "): ";
- llvm::interleaveComma(*values, llvm::dbgs());
- llvm::dbgs() << "\n * Result: ";
- llvm::interleaveComma(values->getType(), llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Values (" << values->size()
+ << "): " << llvm::interleaved(*values)
+ << "\n * Result: " << llvm::interleaved(values->getType());
typeRangeMemory[rangeIndex] = values->getType();
memory[memIndex] = &typeRangeMemory[rangeIndex];
}
void ByteCodeExecutor::executeIsNotNull() {
- LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
+ LDBG() << "Executing IsNotNull:";
const void *value = read<const void *>();
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ LDBG() << " * Value: " << value;
selectJump(value != nullptr);
}
void ByteCodeExecutor::executeRecordMatch(
PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
- LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
+ LDBG() << "Executing RecordMatch:";
unsigned patternIndex = read();
PatternBenefit benefit = currentPatternBenefits[patternIndex];
const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
@@ -2036,7 +1992,7 @@ void ByteCodeExecutor::executeRecordMatch(
// If the benefit of the pattern is impossible, skip the processing of the
// rest of the pattern.
if (benefit.isImpossibleToMatch()) {
- LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
+ LDBG() << " * Benefit: Impossible To Match";
curCodeIt = dest;
return;
}
@@ -2052,8 +2008,8 @@ void ByteCodeExecutor::executeRecordMatch(
matchLocs.push_back(read<Operation *>()->getLoc());
Location matchLoc = rewriter.getFusedLoc(matchLocs);
- LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
- << " * Location: " << matchLoc << "\n");
+ LDBG() << " * Benefit: " << benefit.getBenefit();
+ LDBG() << " * Location: " << matchLoc;
matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
PDLByteCode::MatchResult &match = matches.back();
@@ -2083,38 +2039,34 @@ void ByteCodeExecutor::executeRecordMatch(
}
void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
+ LDBG() << "Executing ReplaceOp:";
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
readList(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Values: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << " * Operation: " << *op
+ << "\n * Values: " << llvm::interleaved(args);
rewriter.replaceOp(op, args);
}
void ByteCodeExecutor::executeSwitchAttribute() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
+ LDBG() << "Executing SwitchAttribute:";
Attribute value = read<Attribute>();
ArrayAttr cases = read<ArrayAttr>();
handleSwitch(value, cases);
}
void ByteCodeExecutor::executeSwitchOperandCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
+ LDBG() << "Executing SwitchOperandCount:";
Operation *op = read<Operation *>();
auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+ LDBG() << " * Operation: " << *op;
handleSwitch(op->getNumOperands(), cases);
}
void ByteCodeExecutor::executeSwitchOperationName() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
+ LDBG() << "Executing SwitchOperationName:";
OperationName value = read<Operation *>()->getName();
size_t caseCount = read();
@@ -2123,13 +2075,11 @@ void ByteCodeExecutor::executeSwitchOperationName() {
// switch so that we can display all of the possible values.
LLVM_DEBUG({
const ByteCodeField *prevCodeIt = curCodeIt;
- llvm::dbgs() << " * Value: " << value << "\n"
- << " * Cases: ";
- llvm::interleaveComma(
- llvm::map_range(llvm::seq<size_t>(0, caseCount),
- [&](size_t) { return read<OperationName>(); }),
- llvm::dbgs());
- llvm::dbgs() << "\n";
+ LDBG() << " * Value: " << value << "\n * Cases: "
+ << llvm::interleaved(
+ llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
+ return read<OperationName>();
+ }));
curCodeIt = prevCodeIt;
});
@@ -2144,27 +2094,27 @@ void ByteCodeExecutor::executeSwitchOperationName() {
}
void ByteCodeExecutor::executeSwitchResultCount() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
+ LDBG() << "Executing SwitchResultCount:";
Operation *op = read<Operation *>();
auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+ LDBG() << " * Operation: " << *op;
handleSwitch(op->getNumResults(), cases);
}
void ByteCodeExecutor::executeSwitchType() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
+ LDBG() << "Executing SwitchType:";
Type value = read<Type>();
auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
handleSwitch(value, cases);
}
void ByteCodeExecutor::executeSwitchTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
+ LDBG() << "Executing SwitchTypes:";
TypeRange *value = read<TypeRange *>();
auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
if (!value) {
- LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
+ LDBG() << "Types: <NULL>";
return selectJump(size_t(0));
}
handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
@@ -2178,7 +2128,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
std::optional<Location> mainRewriteLoc) {
while (true) {
// Print the location of the operation being executed.
- LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
+ LDBG() << readInline<Location>();
OpCode opCode = static_cast<OpCode>(read());
switch (opCode) {
@@ -2239,7 +2189,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
break;
case Finalize:
executeFinalize();
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "";
return success();
case ForEach:
executeForEach();
@@ -2258,12 +2208,12 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case GetOperand2:
case GetOperand3: {
unsigned index = opCode - GetOperand0;
- LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
+ LDBG() << "Executing GetOperand" << index << ":";
executeGetOperand(index);
break;
}
case GetOperandN:
- LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
+ LDBG() << "Executing GetOperandN:";
executeGetOperand(read<uint32_t>());
break;
case GetOperands:
@@ -2274,12 +2224,12 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case GetResult2:
case GetResult3: {
unsigned index = opCode - GetResult0;
- LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
+ LDBG() << "Executing GetResult" << index << ":";
executeGetResult(index);
break;
}
case GetResultN:
- LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
+ LDBG() << "Executing GetResultN:";
executeGetResult(read<uint32_t>());
break;
case GetResults:
@@ -2324,7 +2274,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
executeSwitchTypes();
break;
}
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG() << "";
}
}
@@ -2383,7 +2333,7 @@ LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
// bug in the user code (i.e. failable rewrites should not be used with
// pattern rewriters that don't support it).
if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
- LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
+ LDBG() << " and rollback is not supported - aborting";
llvm::report_fatal_error(
"Native PDL Rewrite failed, but the pattern "
"rewriter doesn't support recovery. Failable pattern rewrites should "
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index e13bcff..23ae95a 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -37,9 +37,9 @@ PatternApplicator::~PatternApplicator() = default;
#ifndef NDEBUG
/// Log a message for a pattern that is impossible to match.
static void logImpossibleToMatch(const Pattern &pattern) {
- llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
- << "' because it is impossible to match or cannot lead "
- "to legal IR (by cost model)\n";
+ LDBG() << "Ignoring pattern '" << pattern.getRootKind()
+ << "' because it is impossible to match or cannot lead "
+ "to legal IR (by cost model)";
}
/// Log IR after pattern application.
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 6eb0abc..f0c3ac4 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(SPIRV)
add_subdirectory(LLVMIR)
add_subdirectory(LLVM)
add_subdirectory(SMTLIB)
+add_subdirectory(Wasm)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 8e83e45..570f38c 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
@@ -35,7 +36,7 @@ using llvm::formatv;
/// on each element doesn't return a string.
template <typename ForwardIterator, typename UnaryFunctor,
typename NullaryFunctor>
-inline LogicalResult
+static inline LogicalResult
interleaveWithError(ForwardIterator begin, ForwardIterator end,
UnaryFunctor eachFn, NullaryFunctor betweenFn) {
if (begin == end)
@@ -52,16 +53,16 @@ interleaveWithError(ForwardIterator begin, ForwardIterator end,
}
template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
-inline LogicalResult interleaveWithError(const Container &c,
- UnaryFunctor eachFn,
- NullaryFunctor betweenFn) {
+static inline LogicalResult interleaveWithError(const Container &c,
+ UnaryFunctor eachFn,
+ NullaryFunctor betweenFn) {
return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
}
template <typename Container, typename UnaryFunctor>
-inline LogicalResult interleaveCommaWithError(const Container &c,
- raw_ostream &os,
- UnaryFunctor eachFn) {
+static inline LogicalResult interleaveCommaWithError(const Container &c,
+ raw_ostream &os,
+ UnaryFunctor eachFn) {
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
}
@@ -364,9 +365,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
if (hasDeferredEmission(user))
return false;
- // Do not inline expressions used by ops with the CExpressionInterface. If
- // this was intended, the user could have been merged into the expression op.
- return !isa<emitc::CExpressionInterface>(*user);
+ // Do not inline expressions used by other expressions or by ops with the
+ // CExpressionInterface. If this was intended, the user could have been merged
+ // into the expression op.
+ return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user);
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -749,11 +751,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (t.getType().isIndex()) {
int64_t idx = t.getInt();
Value operand = op.getOperand(idx);
- if (!emitter.hasValueInScope(operand))
- return op.emitOpError("operand ")
- << idx << "'s value not defined in scope";
- os << emitter.getOrCreateName(operand);
- return success();
+ return emitter.emitOperand(operand);
}
}
if (failed(emitter.emitAttribute(op.getLoc(), attr)))
@@ -782,9 +780,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << applyOp.getApplicableOperator();
- os << emitter.getOrCreateName(applyOp.getOperand());
-
- return success();
+ return emitter.emitOperand(applyOp.getOperand());
}
static LogicalResult printOperation(CppEmitter &emitter,
@@ -1447,7 +1443,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
}
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
if (auto iType = dyn_cast<IntegerType>(
- cast<TensorType>(dense.getType()).getElementType())) {
+ cast<ShapedType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -1456,7 +1452,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return success();
}
if (auto iType = dyn_cast<IndexType>(
- cast<TensorType>(dense.getType()).getElementType())) {
+ cast<ShapedType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os,
[&](const APInt &val) { printInt(val, false); });
@@ -1538,6 +1534,20 @@ LogicalResult CppEmitter::emitOperand(Value value) {
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
+ // If this operand is a block argument of an expression, emit instead the
+ // matching expression parameter.
+ Operation *argOp = arg.getParentBlock()->getParentOp();
+ if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) {
+ // This scenario is only expected when one of the operations within the
+ // expression being emitted references one of the expression's block
+ // arguments.
+ assert(expressionOp == emittedExpression &&
+ "Expected expression being emitted");
+ value = expressionOp->getOperand(arg.getArgNumber());
+ }
+ }
+
os << getOrCreateName(value);
return success();
}
@@ -1793,7 +1803,7 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
case 16: {
if (llvm::isa<Float16Type>(type))
return (os << "_Float16"), success();
- else if (llvm::isa<BFloat16Type>(type))
+ if (llvm::isa<BFloat16Type>(type))
return (os << "__bf16"), success();
else
return emitError(loc, "cannot emit float type ") << type;
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index f6e44c6..9a0e4d4 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -210,3 +210,27 @@ if(MLIR_ENABLE_ROCM_CONVERSIONS)
)
endif()
+if ("SPIRV" IN_LIST LLVM_TARGETS_TO_BUILD)
+ set(SPIRV_LIBS
+ SPIRVCodeGen
+ SPIRVDesc
+ SPIRVInfo
+ )
+endif()
+
+add_mlir_dialect_library(MLIRXeVMTarget
+ XeVM/Target.cpp
+
+ OBJECT
+
+ LINK_COMPONENTS
+ ${SPIRV_LIBS}
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRExecutionEngineUtils
+ MLIRSupport
+ MLIRGPUDialect
+ MLIRTargetLLVM
+ MLIRXeVMToLLVMIRTranslation
+)
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 55c8a64..8760ea8 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -24,9 +24,11 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/Support/InterleavedRange.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Config/Targets.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/FormatVariadic.h"
@@ -265,6 +267,8 @@ NVPTXSerializer::NVPTXSerializer(Operation &module, NVVMTargetAttr target,
std::optional<NVPTXSerializer::TmpFile>
NVPTXSerializer::createTemp(StringRef name, StringRef suffix) {
llvm::SmallString<128> filename;
+ if (name.size() > 80)
+ name = name.substr(0, 80);
std::error_code ec =
llvm::sys::fs::createTemporaryFile(name, suffix, filename);
if (ec) {
@@ -452,17 +456,11 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
// Dump tool invocation commands.
#define DEBUG_TYPE "serialize-to-binary"
- LLVM_DEBUG({
- llvm::dbgs() << "Tool invocation for module: "
- << getOperation().getNameAttr() << "\n";
- llvm::dbgs() << "ptxas executable:" << ptxasCompiler.value() << "\n";
- llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
- llvm::dbgs() << "\n";
- if (createFatbin) {
- llvm::interleave(fatbinArgs, llvm::dbgs(), " ");
- llvm::dbgs() << "\n";
- }
- });
+ LDBG() << "Tool invocation for module: " << getOperation().getNameAttr()
+ << "\nptxas executable:" << ptxasCompiler.value()
+ << "\nptxas args: " << llvm::interleaved(ptxasArgs, " ");
+ if (createFatbin)
+ LDBG() << "fatbin args: " << llvm::interleaved(fatbinArgs, " ");
#undef DEBUG_TYPE
// Helper function for printing tool error logs.
@@ -507,7 +505,7 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
llvm::MemoryBuffer::getFile(logFile->first);
if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
- llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
+ LDBG() << "Output:\n" << (*logBuffer)->getBuffer();
llvm::dbgs().flush();
}
});
@@ -529,7 +527,7 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
llvm::MemoryBuffer::getFile(logFile->first);
if (logBuffer && !(*logBuffer)->getBuffer().empty()) {
- llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
+ LDBG() << "Output:\n" << (*logBuffer)->getBuffer();
llvm::dbgs().flush();
}
});
@@ -629,12 +627,11 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
SmallVector<char> log(logSize + 1, 0);
RETURN_ON_NVPTXCOMPILER_ERROR(
nvPTXCompilerGetInfoLog(compiler, log.data()));
- llvm::dbgs() << "NVPTX compiler invocation for module: "
- << getOperation().getNameAttr() << "\n";
- llvm::dbgs() << "Arguments: ";
- llvm::interleave(cmdOpts.second, llvm::dbgs(), " ");
- llvm::dbgs() << "\nOutput\n" << log.data() << "\n";
- llvm::dbgs().flush();
+ LDBG() << "NVPTX compiler invocation for module: "
+ << getOperation().getNameAttr()
+ << "\nArguments: " << llvm::interleaved(cmdOpts.second, " ")
+ << "\nOutput\n"
+ << log.data();
}
});
#undef DEBUG_TYPE
@@ -678,10 +675,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
// Return LLVM IR if the compilation target is `offload`.
#define DEBUG_TYPE "serialize-to-llvm"
LLVM_DEBUG({
- llvm::dbgs() << "LLVM IR for module: " << getOperation().getNameAttr()
- << "\n";
- llvm::dbgs() << llvmModule << "\n";
- llvm::dbgs().flush();
+ LDBG() << "LLVM IR for module: " << getOperation().getNameAttr();
+ LDBG() << llvmModule;
});
#undef DEBUG_TYPE
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload)
@@ -716,11 +711,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
isaCallback(serializedISA.value());
#define DEBUG_TYPE "serialize-to-isa"
- LLVM_DEBUG({
- llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
- llvm::dbgs() << *serializedISA << "\n";
- llvm::dbgs().flush();
- });
+ LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
+ << *serializedISA;
#undef DEBUG_TYPE
// Return PTX if the compilation target is `assembly`.
diff --git a/mlir/lib/Target/LLVM/XeVM/Target.cpp b/mlir/lib/Target/LLVM/XeVM/Target.cpp
new file mode 100644
index 0000000..1e6784a2
--- /dev/null
+++ b/mlir/lib/Target/LLVM/XeVM/Target.cpp
@@ -0,0 +1,418 @@
+//===- Target.cpp - MLIR LLVM XeVM target compilation -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This files defines XeVM target related functions including registration
+// calls for the `#xevm.target` compilation attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/XeVM/Target.h"
+
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/Target/LLVM/XeVM/Utils.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/Target/TargetMachine.h"
+
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/Config/Targets.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+#include <cstdlib>
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+namespace {
+// XeVM implementation of the gpu:TargetAttrInterface.
+class XeVMTargetAttrImpl
+ : public gpu::TargetAttrInterface::FallbackModel<XeVMTargetAttrImpl> {
+public:
+ std::optional<SmallVector<char, 0>>
+ serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const;
+
+ Attribute createObject(Attribute attribute, Operation *module,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const;
+};
+} // namespace
+
+void mlir::xevm::registerXeVMTargetInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
+ XeVMTargetAttr::attachInterface<XeVMTargetAttrImpl>(*ctx);
+ });
+}
+
+void mlir::xevm::registerXeVMTargetInterfaceExternalModels(
+ MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMTargetInterfaceExternalModels(registry);
+ context.appendDialectRegistry(registry);
+}
+
+SerializeGPUModuleBase::SerializeGPUModuleBase(
+ Operation &module, XeVMTargetAttr xeTarget,
+ const gpu::TargetOptions &targetOptions)
+ : ModuleToObject(module, xeTarget.getTriple(), "", {}, xeTarget.getO()),
+ xeTarget(xeTarget), librariesToLink(targetOptions.getLibrariesToLink()),
+ targetOptions(targetOptions) {
+ if (xeTarget.getLinkFiles())
+ librariesToLink.append(xeTarget.getLinkFiles().begin(),
+ xeTarget.getLinkFiles().end());
+}
+
+XeVMTargetAttr SerializeGPUModuleBase::getTarget() const { return xeTarget; }
+
+std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
+SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) {
+ if (librariesToLink.empty())
+ return SmallVector<std::unique_ptr<llvm::Module>>();
+ SmallVector<std::unique_ptr<llvm::Module>> bcFiles;
+ if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink,
+ bcFiles)))
+ return std::nullopt;
+ return std::move(bcFiles);
+}
+
+gpu::GPUModuleOp SerializeGPUModuleBase::getGPUModuleOp() {
+ return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
+}
+
+// There is 1 way to finalize IL to native code: IGC
+// There are 2 ways to access IGC: AOT (ocloc) and JIT (L0 runtime).
+// - L0 runtime consumes IL and is external to MLIR codebase (rt wrappers).
+// - `ocloc` tool can be "queried" from within MLIR.
+std::optional<SmallVector<char, 0>>
+SerializeGPUModuleBase::compileToBinary(const std::string &asmStr,
+ StringRef inputFormat) {
+ using TmpFile = std::pair<llvm::SmallString<128>, llvm::FileRemover>;
+ // Find the `ocloc` tool.
+ std::optional<std::string> oclocCompiler = findTool("ocloc");
+ if (!oclocCompiler)
+ return std::nullopt;
+ Location loc = getGPUModuleOp().getLoc();
+ std::string basename = llvm::formatv(
+ "mlir-{0}-{1}-{2}", getGPUModuleOp().getNameAttr().getValue(),
+ getTarget().getTriple(), getTarget().getChip());
+
+ auto createTemp = [&](StringRef name,
+ StringRef suffix) -> std::optional<TmpFile> {
+ llvm::SmallString<128> filePath;
+ if (auto ec = llvm::sys::fs::createTemporaryFile(name, suffix, filePath)) {
+ getGPUModuleOp().emitError()
+ << "Couldn't create the temp file: `" << filePath
+ << "`, error message: " << ec.message();
+ return std::nullopt;
+ }
+ return TmpFile(filePath, llvm::FileRemover(filePath.c_str()));
+ };
+ // Create temp file
+ std::optional<TmpFile> asmFile = createTemp(basename, "asm");
+ std::optional<TmpFile> binFile = createTemp(basename, "");
+ std::optional<TmpFile> logFile = createTemp(basename, "log");
+ if (!logFile || !asmFile || !binFile)
+ return std::nullopt;
+ // Dump the assembly to a temp file
+ std::error_code ec;
+ {
+ llvm::raw_fd_ostream asmStream(asmFile->first, ec);
+ if (ec) {
+ emitError(loc) << "Couldn't open the file: `" << asmFile->first
+ << "`, error message: " << ec.message();
+ return std::nullopt;
+ }
+ asmStream << asmStr;
+ if (asmStream.has_error()) {
+ emitError(loc) << "An error occurred while writing the assembly to: `"
+ << asmFile->first << "`.";
+ return std::nullopt;
+ }
+ asmStream.flush();
+ }
+ // Set cmd options
+ std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
+ targetOptions.tokenizeCmdOptions();
+ // Example: --gpu-module-to-binary="opts='opt1 opt2'"
+ const std::string cmdOptsStr = "\"" + llvm::join(cmdOpts.second, " ") + "\"";
+ SmallVector<StringRef, 12> oclocArgs(
+ {"ocloc", "compile", "-file", asmFile->first, inputFormat, "-device",
+ getTarget().getChip(), "-output", binFile->first, "-output_no_suffix",
+ "-options", cmdOptsStr});
+
+// Dump tool invocation commands.
+#define DEBUG_TYPE "serialize-to-binary"
+ LLVM_DEBUG({
+ llvm::dbgs() << "Tool invocation for module: "
+ << getGPUModuleOp().getNameAttr() << "\n";
+ llvm::interleave(oclocArgs, llvm::dbgs(), " ");
+ llvm::dbgs() << "\n";
+ });
+#undef DEBUG_TYPE
+ // Helper function for printing tool error logs.
+ std::string message;
+ auto emitLogError =
+ [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
+ if (message.empty()) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
+ llvm::MemoryBuffer::getFile(logFile->first);
+ if (toolStderr)
+ emitError(loc) << toolName << " invocation failed. Log:\n"
+ << toolStderr->get()->getBuffer();
+ else
+ emitError(loc) << toolName << " invocation failed.";
+ return std::nullopt;
+ }
+ emitError(loc) << toolName
+ << " invocation failed, error message: " << message;
+ return std::nullopt;
+ };
+ std::optional<StringRef> redirects[] = {
+ std::nullopt,
+ logFile->first,
+ logFile->first,
+ };
+ // Invoke ocloc.
+ if (llvm::sys::ExecuteAndWait(oclocCompiler.value(), oclocArgs, std::nullopt,
+ redirects, 0, 0, &message))
+ return emitLogError("`ocloc`");
+ binFile->first.append(".bin");
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
+ llvm::MemoryBuffer::getFile(binFile->first);
+ if (!binaryBuffer) {
+ emitError(loc) << "Couldn't open the file: `" << binFile->first
+ << "`, error message: " << binaryBuffer.getError().message();
+ return std::nullopt;
+ }
+ StringRef bin = (*binaryBuffer)->getBuffer();
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+}
+
+std::optional<std::string> SerializeGPUModuleBase::findTool(StringRef tool) {
+ // 1. Check the toolkit path given in the command line.
+ StringRef pathRef = targetOptions.getToolkitPath();
+ SmallVector<char, 256> path;
+ if (!pathRef.empty()) {
+ path.insert(path.begin(), pathRef.begin(), pathRef.end());
+ llvm::sys::path::append(path, "bin", tool);
+ if (llvm::sys::fs::can_execute(path))
+ return StringRef(path.data(), path.size()).str();
+ }
+ // 2. Check PATH.
+ if (std::optional<std::string> toolPath =
+ llvm::sys::Process::FindInEnvPath("PATH", tool))
+ return *toolPath;
+
+ getGPUModuleOp().emitError()
+ << "Couldn't find the `" << tool
+ << "` binary. Please specify the toolkit "
+ "path via GpuModuleToBinaryPass or add the compiler to $PATH`.";
+ return std::nullopt;
+}
+
+namespace {
+class SPIRVSerializer : public SerializeGPUModuleBase {
+public:
+ SPIRVSerializer(Operation &module, XeVMTargetAttr xeTarget,
+ const gpu::TargetOptions &targetOptions)
+ : SerializeGPUModuleBase(module, xeTarget, targetOptions) {}
+
+ static void init();
+
+ /// Serializes the LLVM module to an object format, depending on the
+ /// compilation target selected in target options.
+ std::optional<SmallVector<char, 0>>
+ moduleToObject(llvm::Module &llvmModule) override;
+
+private:
+ /// Translates the LLVM module to SPIR-V binary using LLVM's
+ /// SPIR-V target.
+ std::optional<std::string>
+ translateToSPIRVBinary(llvm::Module &llvmModule,
+ llvm::TargetMachine &targetMachine);
+};
+} // namespace
+
+void SPIRVSerializer::init() {
+ static llvm::once_flag initializeBackendOnce;
+ llvm::call_once(initializeBackendOnce, []() {
+#if LLVM_HAS_SPIRV_TARGET
+ LLVMInitializeSPIRVTarget();
+ LLVMInitializeSPIRVTargetInfo();
+ LLVMInitializeSPIRVTargetMC();
+ LLVMInitializeSPIRVAsmPrinter();
+#endif
+ });
+}
+
+std::optional<SmallVector<char, 0>>
+SPIRVSerializer::moduleToObject(llvm::Module &llvmModule) {
+#define DEBUG_TYPE "serialize-to-llvm"
+ LLVM_DEBUG({
+ llvm::dbgs() << "LLVM IR for module: " << getGPUModuleOp().getNameAttr()
+ << "\n";
+ llvm::dbgs() << llvmModule << "\n";
+ llvm::dbgs().flush();
+ });
+#undef DEBUG_TYPE
+
+ // Return LLVM IR if the compilation target is `offload`.
+ if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload)
+ return SerializeGPUModuleBase::moduleToObject(llvmModule);
+
+#if !LLVM_HAS_SPIRV_TARGET
+ getGPUModuleOp()->emitError("The `SPIRV` target was not built. Please enable "
+ "it when building LLVM.");
+ return std::nullopt;
+#endif // LLVM_HAS_SPIRV_TARGET
+
+ std::optional<llvm::TargetMachine *> targetMachine =
+ getOrCreateTargetMachine();
+ if (!targetMachine) {
+ getGPUModuleOp().emitError() << "Target Machine unavailable for triple "
+ << triple << ", can't optimize with LLVM\n";
+ return std::nullopt;
+ }
+
+ // Return SPIRV if the compilation target is `assembly`.
+ if (targetOptions.getCompilationTarget() ==
+ gpu::CompilationTarget::Assembly) {
+ std::optional<std::string> serializedISA =
+ translateToISA(llvmModule, **targetMachine);
+ if (!serializedISA) {
+ getGPUModuleOp().emitError() << "Failed translating the module to ISA."
+ << triple << ", can't compile with LLVM\n";
+ return std::nullopt;
+ }
+
+#define DEBUG_TYPE "serialize-to-isa"
+ LLVM_DEBUG({
+ llvm::dbgs() << "SPIR-V for module: " << getGPUModuleOp().getNameAttr()
+ << "\n";
+ llvm::dbgs() << *serializedISA << "\n";
+ llvm::dbgs().flush();
+ });
+#undef DEBUG_TYPE
+
+ // Make sure to include the null terminator.
+ StringRef bin(serializedISA->c_str(), serializedISA->size() + 1);
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+ }
+
+ // Level zero runtime is set up to accept SPIR-V binary
+ // translateToSPIRVBinary translates the LLVM module to SPIR-V binary
+ // using LLVM's SPIRV target.
+ // compileToBinary can be used in the future if level zero runtime
+ // implementation switches to native XeVM binary format.
+ std::optional<std::string> serializedSPIRVBinary =
+ translateToSPIRVBinary(llvmModule, **targetMachine);
+ if (!serializedSPIRVBinary) {
+ getGPUModuleOp().emitError() << "Failed translating the module to Binary.";
+ return std::nullopt;
+ }
+ if (serializedSPIRVBinary->size() % 4) {
+ getGPUModuleOp().emitError() << "SPIRV code size must be a multiple of 4.";
+ return std::nullopt;
+ }
+ StringRef bin(serializedSPIRVBinary->c_str(), serializedSPIRVBinary->size());
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+}
+
+std::optional<std::string>
+SPIRVSerializer::translateToSPIRVBinary(llvm::Module &llvmModule,
+ llvm::TargetMachine &targetMachine) {
+ std::string targetISA;
+ llvm::raw_string_ostream stream(targetISA);
+
+ { // Drop pstream after this to prevent the ISA from being stuck buffering
+ llvm::buffer_ostream pstream(stream);
+ llvm::legacy::PassManager codegenPasses;
+ if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
+ llvm::CodeGenFileType::ObjectFile))
+ return std::nullopt;
+
+ codegenPasses.run(llvmModule);
+ }
+ return targetISA;
+}
+
+std::optional<SmallVector<char, 0>>
+XeVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const {
+ if (!module)
+ return std::nullopt;
+ auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
+ if (!gpuMod) {
+ module->emitError("expected to be a gpu.module op");
+ return std::nullopt;
+ }
+ auto xeTarget = cast<XeVMTargetAttr>(attribute);
+ if (xeTarget.getTriple().starts_with("spirv")) {
+ gpuMod.walk([&](LLVM::LLVMFuncOp funcOp) {
+ if (funcOp->hasAttr(gpu::GPUDialect::getKernelFuncAttrName())) {
+ funcOp.setIntelReqdSubGroupSize(16);
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ SPIRVSerializer serializer(*module, cast<XeVMTargetAttr>(attribute),
+ options);
+ serializer.init();
+
+#if !LLVM_HAS_SPIRV_TARGET
+ module->emitError("Cannot run `TargetRegistry::lookupTarget()` for SPIRV "
+ "without having the target built.");
+#endif
+
+ return serializer.run();
+ }
+ module->emitError("Unsupported XeVM target triple: ") << xeTarget.getTriple();
+ return std::nullopt;
+}
+
+Attribute
+XeVMTargetAttrImpl::createObject(Attribute attribute, Operation *module,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const {
+ Builder builder(attribute.getContext());
+ gpu::CompilationTarget format = options.getCompilationTarget();
+ auto xeTarget = cast<XeVMTargetAttr>(attribute);
+ SmallVector<NamedAttribute, 2> properties;
+ if (format == gpu::CompilationTarget::Assembly)
+ properties.push_back(
+ builder.getNamedAttr("O", builder.getI32IntegerAttr(xeTarget.getO())));
+
+ DictionaryAttr objectProps;
+ if (!properties.empty())
+ objectProps = builder.getDictionaryAttr(properties);
+
+ return builder.getAttr<gpu::ObjectAttr>(
+ attribute, format,
+ builder.getStringAttr(StringRef(object.data(), object.size())),
+ objectProps, /*kernels=*/nullptr);
+}
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 9ea5c683..a73a78d 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Dialect)
+add_subdirectory(Transforms)
set(LLVM_OPTIONAL_SOURCES
ConvertFromLLVMIR.cpp
@@ -58,6 +59,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIROpenACCToLLVMIRTranslation
MLIROpenMPToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
+ MLIRPtrToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
MLIRXeVMToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp b/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp
index fbad5c2..8bd07cd 100644
--- a/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp
@@ -6,13 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#include "DataLayoutImporter.h"
+#include "mlir/Target/LLVMIR/DataLayoutImporter.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Target/LLVMIR/Import.h"
+
#include "llvm/IR/DataLayout.h"
using namespace mlir;
@@ -274,101 +275,88 @@ DataLayoutImporter::tryToEmplaceLegalIntWidthsEntry(StringRef token) {
return success();
}
-void DataLayoutImporter::translateDataLayout(
- const llvm::DataLayout &llvmDataLayout) {
- dataLayout = {};
-
- // Transform the data layout to its string representation and append the
- // default data layout string specified in the language reference
- // (https://llvm.org/docs/LangRef.html#data-layout). The translation then
- // parses the string and ignores the default value if a specific kind occurs
- // in both strings. Additionally, the following default values exist:
- // - non-default address space pointer specifications default to the default
- // address space pointer specification
- // - the alloca address space defaults to the default address space.
- layoutStr = llvmDataLayout.getStringRepresentation();
- if (!layoutStr.empty())
- layoutStr += "-";
- layoutStr += kDefaultDataLayout;
- StringRef layout(layoutStr);
+DataLayoutSpecInterface DataLayoutImporter::dataLayoutSpecFromDataLayoutStr() {
+ if (!dataLayoutStr.empty())
+ dataLayoutStr += "-";
+ dataLayoutStr += kDefaultDataLayout;
// Split the data layout string into tokens separated by a dash.
SmallVector<StringRef> tokens;
- layout.split(tokens, '-');
+ StringRef(dataLayoutStr).split(tokens, '-');
for (StringRef token : tokens) {
lastToken = token;
FailureOr<StringRef> prefix = tryToParseAlphaPrefix(token);
if (failed(prefix))
- return;
+ return {};
// Parse the endianness.
if (*prefix == "e") {
if (failed(tryToEmplaceEndiannessEntry(
DLTIDialect::kDataLayoutEndiannessLittle, token)))
- return;
+ return {};
continue;
}
if (*prefix == "E") {
if (failed(tryToEmplaceEndiannessEntry(
DLTIDialect::kDataLayoutEndiannessBig, token)))
- return;
+ return {};
continue;
}
// Parse the program address space.
if (*prefix == "P") {
if (failed(tryToEmplaceAddrSpaceEntry(
token, DLTIDialect::kDataLayoutProgramMemorySpaceKey)))
- return;
+ return {};
continue;
}
// Parse the mangling mode.
if (*prefix == "m") {
if (failed(tryToEmplaceManglingModeEntry(
token, DLTIDialect::kDataLayoutManglingModeKey)))
- return;
+ return {};
continue;
}
// Parse the global address space.
if (*prefix == "G") {
if (failed(tryToEmplaceAddrSpaceEntry(
token, DLTIDialect::kDataLayoutGlobalMemorySpaceKey)))
- return;
+ return {};
continue;
}
// Parse the alloca address space.
if (*prefix == "A") {
if (failed(tryToEmplaceAddrSpaceEntry(
token, DLTIDialect::kDataLayoutAllocaMemorySpaceKey)))
- return;
+ return {};
continue;
}
// Parse the stack alignment.
if (*prefix == "S") {
if (failed(tryToEmplaceStackAlignmentEntry(token)))
- return;
+ return {};
continue;
}
// Parse integer alignment specifications.
if (*prefix == "i") {
FailureOr<uint64_t> width = tryToParseInt(token);
if (failed(width))
- return;
+ return {};
Type type = IntegerType::get(context, *width);
if (failed(tryToEmplaceAlignmentEntry(type, token)))
- return;
+ return {};
continue;
}
// Parse float alignment specifications.
if (*prefix == "f") {
FailureOr<uint64_t> width = tryToParseInt(token);
if (failed(width))
- return;
+ return {};
Type type = getFloatType(context, *width);
if (failed(tryToEmplaceAlignmentEntry(type, token)))
- return;
+ return {};
continue;
}
// Parse pointer alignment specifications.
@@ -376,17 +364,17 @@ void DataLayoutImporter::translateDataLayout(
FailureOr<uint64_t> space =
token.starts_with(":") ? 0 : tryToParseInt(token);
if (failed(space))
- return;
+ return {};
auto type = LLVMPointerType::get(context, *space);
if (failed(tryToEmplacePointerAlignmentEntry(type, token)))
- return;
+ return {};
continue;
}
// Parse native integer widths specifications.
if (*prefix == "n") {
if (failed(tryToEmplaceLegalIntWidthsEntry(token)))
- return;
+ return {};
continue;
}
// Parse function pointer alignment specifications.
@@ -394,7 +382,7 @@ void DataLayoutImporter::translateDataLayout(
if (prefix->starts_with("F")) {
StringRef nextPrefix = prefix->drop_front(1);
if (failed(tryToEmplaceFunctionPointerAlignmentEntry(nextPrefix, token)))
- return;
+ return {};
continue;
}
@@ -409,11 +397,12 @@ void DataLayoutImporter::translateDataLayout(
entries.push_back(it.second);
for (const auto &it : keyEntries)
entries.push_back(it.second);
- dataLayout = DataLayoutSpecAttr::get(context, entries);
+ return DataLayoutSpecAttr::get(context, entries);
}
DataLayoutSpecInterface
mlir::translateDataLayout(const llvm::DataLayout &dataLayout,
MLIRContext *context) {
- return DataLayoutImporter(context, dataLayout).getDataLayout();
+ return DataLayoutImporter(context, dataLayout.getStringRepresentation())
+ .getDataLayoutSpec();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index 86c731a..a102c43 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(NVVM)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
+add_subdirectory(Ptr)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 0f675a0..fd8463a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -18,6 +18,7 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
@@ -358,6 +359,17 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
}
}
+static llvm::DILocalScope *
+getLocalScopeFromLoc(llvm::IRBuilderBase &builder, Location loc,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ if (auto scopeLoc =
+ loc->findInstanceOf<FusedLocWith<LLVM::DILocalScopeAttr>>())
+ if (auto *localScope = llvm::dyn_cast<llvm::DILocalScope>(
+ moduleTranslation.translateDebugInfo(scopeLoc.getMetadata())))
+ return localScope;
+ return builder.GetInsertBlock()->getParent()->getSubprogram();
+}
+
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 90462d1..7f69af14 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
llvm_unreachable("unsupported vote kind");
}
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
- int32_t num) {
- if (layout == NVVM::MMALayout::row) {
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShapeAttr shape,
+ NVVM::LdStMatrixEltType eltType) {
+ if (shape.getM() == 8 && shape.getN() == 8) {
switch (num) {
case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
}
-
- } else {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ } else if (shape.getM() == 8 && shape.getN() == 16) {
+ if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+ }
+ }
+ } else if (shape.getM() == 16 && shape.getN() == 16) {
+ if (eltType == NVVM::LdStMatrixEltType::B8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+ }
}
}
+ llvm_unreachable("unknown ldmatrix kind");
}
/// Return the intrinsic ID associated with stmatrix for the given paramters.
@@ -418,7 +468,11 @@ public:
} else if (attribute.getName() ==
NVVM::NVVMDialect::getKernelFuncAttrName()) {
llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
+ } else if (attribute.getName() ==
+ NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
+ llvmFunc->addFnAttr("nvvm.blocksareclusters");
}
+
return success();
}
@@ -429,51 +483,10 @@ public:
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
- llvm::NamedMDNode *nvvmAnnotations =
- moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
- llvm::MDNode *gridConstantMetaData = nullptr;
-
- // Check if a 'grid_constant' metadata node exists for the given function
- for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
- if (opnd->getNumOperands() == 3 &&
- opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
- opnd->getOperand(1) ==
- llvm::MDString::get(llvmContext, "grid_constant")) {
- gridConstantMetaData = opnd;
- break;
- }
- }
-
- // 'grid_constant' is a function-level meta data node with a list of
- // integers, where each integer n denotes that the nth parameter has the
- // grid_constant annotation (numbering from 1). This requires aggregating
- // the indices of the individual parameters that have this attribute.
- llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
- if (gridConstantMetaData == nullptr) {
- // Create a new 'grid_constant' metadata node
- SmallVector<llvm::Metadata *> gridConstMetadata = {
- llvm::ValueAsMetadata::getConstant(
- llvm::ConstantInt::get(i32, argIdx + 1))};
- llvm::Metadata *llvmMetadata[] = {
- llvm::ValueAsMetadata::get(llvmFunc),
- llvm::MDString::get(llvmContext, "grid_constant"),
- llvm::MDNode::get(llvmContext, gridConstMetadata)};
- llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmContext, llvmMetadata);
- nvvmAnnotations->addOperand(llvmMetadataNode);
- } else {
- // Append argIdx + 1 to the 'grid_constant' argument list
- if (auto argList =
- dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
- llvm::TempMDTuple clonedArgList = argList->clone();
- clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
- llvm::ConstantInt::get(i32, argIdx + 1))));
- gridConstantMetaData->replaceOperandWith(
- 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
- }
- }
+ llvmFunc->addParamAttr(
+ argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant"));
}
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 762cc88..8a1b554 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2893,6 +2893,12 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
alignment = builder.getInt64(intAttr.getInt());
assert(ty->isPointerTy() && "Invalid type for aligned variable");
assert(alignment && "Invalid alignment value");
+
+ // Check if the alignment value is not a power of 2. If so, skip emitting
+ // alignment.
+ if (!intAttr.getValue().isPowerOf2())
+ continue;
+
auto curInsert = builder.saveIP();
builder.SetInsertPoint(sourceBlock);
llvmVal = builder.CreateLoad(ty, llvmVal);
@@ -3205,6 +3211,23 @@ llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
.Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
}
+void extractAtomicControlFlags(omp::AtomicUpdateOp atomicUpdateOp,
+ bool &isIgnoreDenormalMode,
+ bool &isFineGrainedMemory,
+ bool &isRemoteMemory) {
+ isIgnoreDenormalMode = false;
+ isFineGrainedMemory = false;
+ isRemoteMemory = false;
+ if (atomicUpdateOp &&
+ atomicUpdateOp->hasAttr(atomicUpdateOp.getAtomicControlAttrName())) {
+ mlir::omp::AtomicControlAttr atomicControlAttr =
+ atomicUpdateOp.getAtomicControlAttr();
+ isIgnoreDenormalMode = atomicControlAttr.getIgnoreDenormalMode();
+ isFineGrainedMemory = atomicControlAttr.getFineGrainedMemory();
+ isRemoteMemory = atomicControlAttr.getRemoteMemory();
+ }
+}
+
/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
static LogicalResult
convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
@@ -3269,13 +3292,19 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
return moduleTranslation.lookupValue(yieldop.getResults()[0]);
};
+ bool isIgnoreDenormalMode;
+ bool isFineGrainedMemory;
+ bool isRemoteMemory;
+ extractAtomicControlFlags(opInst, isIgnoreDenormalMode, isFineGrainedMemory,
+ isRemoteMemory);
// Handle ambiguous alloca, if any.
auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
ompBuilder->createAtomicUpdate(ompLoc, allocaIP, llvmAtomicX, llvmExpr,
atomicOrdering, binop, updateFn,
- isXBinopExpr);
+ isXBinopExpr, isIgnoreDenormalMode,
+ isFineGrainedMemory, isRemoteMemory);
if (failed(handleError(afterIP, *opInst)))
return failure();
@@ -3364,13 +3393,19 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
return moduleTranslation.lookupValue(yieldop.getResults()[0]);
};
+ bool isIgnoreDenormalMode;
+ bool isFineGrainedMemory;
+ bool isRemoteMemory;
+ extractAtomicControlFlags(atomicUpdateOp, isIgnoreDenormalMode,
+ isFineGrainedMemory, isRemoteMemory);
// Handle ambiguous alloca, if any.
auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
ompBuilder->createAtomicCapture(
ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
- binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr);
+ binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr,
+ isIgnoreDenormalMode, isFineGrainedMemory, isRemoteMemory);
if (failed(handleError(afterIP, *atomicCaptureOp)))
return failure();
@@ -4327,9 +4362,11 @@ createAlteredByCaptureMap(MapInfoData &mapData,
if (!isPtrTy) {
auto curInsert = builder.saveIP();
+ llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
auto *memTempAlloc =
builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
+ builder.SetCurrentDebugLocation(DbgLoc);
builder.restoreIP(curInsert);
builder.CreateStore(newV, memTempAlloc);
@@ -5836,6 +5873,10 @@ static bool isTargetDeviceOp(Operation *op) {
if (mlir::isa<omp::ThreadprivateOp>(op))
return true;
+ if (mlir::isa<omp::TargetAllocMemOp>(op) ||
+ mlir::isa<omp::TargetFreeMemOp>(op))
+ return true;
+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
if (auto declareTargetIface =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
@@ -5848,6 +5889,85 @@ static bool isTargetDeviceOp(Operation *op) {
return false;
}
+static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *i64Ty = builder.getInt64Ty();
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *returnType = builder.getPtrTy(0);
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
+ llvm::Function *func = cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
+ return func;
+}
+
+static LogicalResult
+convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
+ if (!allocMemOp)
+ return failure();
+
+ // Get "omp_target_alloc" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+ llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = allocMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the allocation size.
+ llvm::DataLayout dataLayout = llvmModule->getDataLayout();
+ mlir::Type heapTy = allocMemOp.getAllocatedType();
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
+ llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ for (auto typeParam : allocMemOp.getTypeparams())
+ allocSize =
+ builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+ // Create call to "omp_target_alloc" with the args as translated llvm values.
+ llvm::CallInst *call =
+ builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
+ llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
+
+ // Map the result
+ moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
+ return success();
+}
+
+static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *ptrTy = builder.getPtrTy(0);
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *voidTy = builder.getVoidTy();
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
+ llvm::Function *func = dyn_cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
+ return func;
+}
+
+static LogicalResult
+convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
+ if (!freeMemOp)
+ return failure();
+
+ // Get "omp_target_free" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+ llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = freeMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the corresponding heapref value in llvm
+ mlir::Value heapref = freeMemOp.getHeapref();
+ llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
+ // Convert heapref int to ptr and call "omp_target_free"
+ llvm::Value *intToPtr =
+ builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
+ builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum});
+ return success();
+}
+
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
static LogicalResult
@@ -6022,6 +6142,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TargetAllocMemOp) {
+ return convertTargetAllocMemOp(*op, builder, moduleTranslation);
+ })
+ .Case([&](omp::TargetFreeMemOp) {
+ return convertTargetFreeMemOp(*op, builder, moduleTranslation);
+ })
.Default([&](Operation *inst) {
return inst->emitError()
<< "not yet implemented: " << inst->getName();
@@ -6258,9 +6384,8 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
if (ompBuilder->Config.isTargetDevice()) {
if (isTargetDeviceOp(op)) {
return convertTargetDeviceOp(op, builder, moduleTranslation);
- } else {
- return convertTargetOpsInNest(op, builder, moduleTranslation);
}
+ return convertTargetOpsInNest(op, builder, moduleTranslation);
}
return convertHostOrTargetOperation(op, builder, moduleTranslation);
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000..f94410d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_translation_library(MLIRPtrToLLVMIRTranslation
+ PtrToLLVMIRTranslation.cpp
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPtrDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+ )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..7b89ec8
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -0,0 +1,66 @@
+//===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR `ptr` dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the `ptr` dialect to LLVM IR.
+class PtrDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Translates the given operation to LLVM IR using the provided IR builder
+ /// and saving the state in `moduleTranslation`.
+ LogicalResult
+ convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ // Translation for ptr dialect operations to LLVM IR is currently
+ // unimplemented.
+ return op->emitError("Translation for ptr dialect operations to LLVM IR is "
+ "not implemented.");
+ }
+
+ /// Attaches module-level metadata for functions marked as kernels.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ // Translation for ptr dialect operations to LLVM IR is currently
+ // unimplemented.
+ return op->emitError("Translation for ptr dialect operations to LLVM IR is "
+ "not implemented.");
+ }
+};
+} // namespace
+
+void mlir::registerPtrDialectTranslation(DialectRegistry &registry) {
+ registry.insert<ptr::PtrDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
+ dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerPtrDialectTranslation(MLIRContext &context) {
+ DialectRegistry registry;
+ registerPtrDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 6325480..7a888bb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -16,7 +16,6 @@
#include "mlir/Target/LLVMIR/Import.h"
#include "AttrKindDetail.h"
-#include "DataLayoutImporter.h"
#include "DebugImporter.h"
#include "LoopAnnotationImporter.h"
@@ -25,6 +24,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Target/LLVMIR/DataLayoutImporter.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/DepthFirstIterator.h"
@@ -1045,8 +1045,9 @@ LogicalResult ModuleImport::convertIFuncs() {
LogicalResult ModuleImport::convertDataLayout() {
Location loc = mlirModule.getLoc();
- DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout());
- if (!dataLayoutImporter.getDataLayout())
+ DataLayoutImporter dataLayoutImporter(
+ context, llvmModule->getDataLayout().getStringRepresentation());
+ if (!dataLayoutImporter.getDataLayoutSpec())
return emitError(loc, "cannot translate data layout: ")
<< dataLayoutImporter.getLastToken();
@@ -1054,7 +1055,7 @@ LogicalResult ModuleImport::convertDataLayout() {
emitWarning(loc, "unhandled data layout token: ") << token;
mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName,
- dataLayoutImporter.getDataLayout());
+ dataLayoutImporter.getDataLayoutSpec());
return success();
}
@@ -1408,6 +1409,67 @@ LogicalResult ModuleImport::convertIFunc(llvm::GlobalIFunc *ifunc) {
return success();
}
+/// Converts LLVM string, integer, and enum attributes into MLIR attributes,
+/// skipping those in `attributesToSkip` and emitting a warning at `loc` for
+/// any other unsupported attributes.
+static ArrayAttr
+convertLLVMAttributesToMLIR(Location loc, MLIRContext *context,
+ llvm::AttributeSet attributes,
+ ArrayRef<StringLiteral> attributesToSkip = {}) {
+ SmallVector<Attribute> mlirAttributes;
+ for (llvm::Attribute attr : attributes) {
+ StringRef attrName;
+ if (attr.isStringAttribute())
+ attrName = attr.getKindAsString();
+ else
+ attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
+ if (llvm::is_contained(attributesToSkip, attrName))
+ continue;
+
+ auto keyAttr = StringAttr::get(context, attrName);
+ if (attr.isStringAttribute()) {
+ StringRef val = attr.getValueAsString();
+ if (val.empty()) {
+ // For string attributes without values, add only the attribute name.
+ mlirAttributes.push_back(keyAttr);
+ continue;
+ }
+ // For string attributes with a value, create a [name, value] pair.
+ mlirAttributes.push_back(
+ ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
+ continue;
+ }
+ if (attr.isIntAttribute()) {
+ // For integer attributes, convert the value to a string and create a
+ // [name, value] pair.
+ auto val = std::to_string(attr.getValueAsInt());
+ mlirAttributes.push_back(
+ ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
+ continue;
+ }
+ if (attr.isEnumAttribute()) {
+ // For enum attributes, add only the attribute name.
+ mlirAttributes.push_back(keyAttr);
+ continue;
+ }
+
+ emitWarning(loc)
+ << "'" << attrName
+ << "' attribute is invalid on current operation, skipping it";
+ }
+ return ArrayAttr::get(context, mlirAttributes);
+}
+
+/// Converts LLVM attributes from `globalVar` into MLIR attributes and adds them
+/// to `globalOp` as target-specific attributes.
+static void processTargetSpecificAttrs(llvm::GlobalVariable *globalVar,
+ GlobalOp globalOp) {
+ ArrayAttr targetSpecificAttrs = convertLLVMAttributesToMLIR(
+ globalOp.getLoc(), globalOp.getContext(), globalVar->getAttributes());
+ if (!targetSpecificAttrs.empty())
+ globalOp.setTargetSpecificAttrsAttr(targetSpecificAttrs);
+}
+
LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
// Insert the global after the last one or at the start of the module.
OpBuilder::InsertionGuard guard = setGlobalInsertionPoint();
@@ -1473,6 +1535,8 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
if (globalVar->hasComdat())
globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat()));
+ processTargetSpecificAttrs(globalVar, globalOp);
+
return success();
}
@@ -2525,7 +2589,7 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
// List of LLVM IR attributes that map to an explicit attribute on the MLIR
// LLVMFuncOp.
-static constexpr std::array kExplicitAttributes{
+static constexpr std::array kExplicitLLVMFuncOpAttributes{
StringLiteral("aarch64_in_za"),
StringLiteral("aarch64_inout_za"),
StringLiteral("aarch64_new_za"),
@@ -2535,7 +2599,6 @@ static constexpr std::array kExplicitAttributes{
StringLiteral("aarch64_pstate_sm_compatible"),
StringLiteral("aarch64_pstate_sm_enabled"),
StringLiteral("alwaysinline"),
- StringLiteral("approx-func-fp-math"),
StringLiteral("convergent"),
StringLiteral("denormal-fp-math"),
StringLiteral("denormal-fp-math-f32"),
@@ -2543,6 +2606,7 @@ static constexpr std::array kExplicitAttributes{
StringLiteral("frame-pointer"),
StringLiteral("instrument-function-entry"),
StringLiteral("instrument-function-exit"),
+ StringLiteral("memory"),
StringLiteral("no-infs-fp-math"),
StringLiteral("no-nans-fp-math"),
StringLiteral("no-signed-zeros-fp-math"),
@@ -2557,61 +2621,17 @@ static constexpr std::array kExplicitAttributes{
StringLiteral("willreturn"),
};
+/// Converts LLVM attributes from `func` into MLIR attributes and adds them
+/// to `funcOp` as passthrough attributes, skipping those listed in
+/// `kExplicitLLVMFuncAttributes`.
static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
- MLIRContext *context = funcOp.getContext();
- SmallVector<Attribute> passthroughs;
llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes(
llvm::AttributeList::AttrIndex::FunctionIndex);
- for (llvm::Attribute attr : funcAttrs) {
- // Skip the memory attribute since the LLVMFuncOp has an explicit memory
- // attribute.
- if (attr.hasAttribute(llvm::Attribute::Memory))
- continue;
-
- // Skip invalid type attributes.
- if (attr.isTypeAttribute()) {
- emitWarning(funcOp.getLoc(),
- "type attributes on a function are invalid, skipping it");
- continue;
- }
-
- StringRef attrName;
- if (attr.isStringAttribute())
- attrName = attr.getKindAsString();
- else
- attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
- auto keyAttr = StringAttr::get(context, attrName);
-
- // Skip attributes that map to an explicit attribute on the LLVMFuncOp.
- if (llvm::is_contained(kExplicitAttributes, attrName))
- continue;
-
- if (attr.isStringAttribute()) {
- StringRef val = attr.getValueAsString();
- if (val.empty()) {
- passthroughs.push_back(keyAttr);
- continue;
- }
- passthroughs.push_back(
- ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
- continue;
- }
- if (attr.isIntAttribute()) {
- auto val = std::to_string(attr.getValueAsInt());
- passthroughs.push_back(
- ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
- continue;
- }
- if (attr.isEnumAttribute()) {
- passthroughs.push_back(keyAttr);
- continue;
- }
-
- llvm_unreachable("unexpected attribute kind");
- }
-
- if (!passthroughs.empty())
- funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs));
+ ArrayAttr passthroughAttr =
+ convertLLVMAttributesToMLIR(funcOp.getLoc(), funcOp.getContext(),
+ funcAttrs, kExplicitLLVMFuncOpAttributes);
+ if (!passthroughAttr.empty())
+ funcOp.setPassthroughAttr(passthroughAttr);
}
void ModuleImport::processFunctionAttributes(llvm::Function *func,
@@ -2703,10 +2723,6 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
attr.isStringAttribute())
funcOp.setNoNansFpMath(attr.getValueAsBool());
- if (llvm::Attribute attr = func->getFnAttribute("approx-func-fp-math");
- attr.isStringAttribute())
- funcOp.setApproxFuncFpMath(attr.getValueAsBool());
-
if (llvm::Attribute attr = func->getFnAttribute("instrument-function-entry");
attr.isStringAttribute())
funcOp.setInstrumentFunctionEntry(
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b3a06e2..97253591 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1081,6 +1081,83 @@ static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
gv->setDSOLocal(true);
}
+/// Attempts to translate an MLIR attribute identified by `key`, optionally with
+/// the given `value`, into an LLVM IR attribute. Reports errors at `loc` if
+/// any. If the attribute name corresponds to a known LLVM IR attribute kind,
+/// creates the LLVM attribute of that kind; otherwise, keeps it as a string
+/// attribute. Performs additional checks for attributes known to have or not
+/// have a value in order to avoid assertions inside LLVM upon construction.
+static FailureOr<llvm::Attribute>
+convertMLIRAttributeToLLVM(Location loc, llvm::LLVMContext &ctx, StringRef key,
+ StringRef value = StringRef()) {
+ auto kind = llvm::Attribute::getAttrKindFromName(key);
+ if (kind == llvm::Attribute::None)
+ return llvm::Attribute::get(ctx, key, value);
+
+ if (llvm::Attribute::isIntAttrKind(kind)) {
+ if (value.empty())
+ return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
+
+ int64_t result;
+ if (!value.getAsInteger(/*Radix=*/0, result))
+ return llvm::Attribute::get(ctx, kind, result);
+ return llvm::Attribute::get(ctx, key, value);
+ }
+
+ if (!value.empty())
+ return emitError(loc) << "LLVM attribute '" << key
+ << "' does not expect a value, found '" << value
+ << "'";
+
+ return llvm::Attribute::get(ctx, kind);
+}
+
+/// Converts the MLIR attributes listed in the given array attribute into LLVM
+/// attributes. Returns an `AttrBuilder` containing the converted attributes.
+/// Reports error to `loc` if any and returns immediately. Expects `arrayAttr`
+/// to contain either string attributes, treated as value-less LLVM attributes,
+/// or array attributes containing two string attributes, with the first string
+/// being the name of the corresponding LLVM attribute and the second string
+/// beings its value. Note that even integer attributes are expected to have
+/// their values expressed as strings.
+static FailureOr<llvm::AttrBuilder>
+convertMLIRAttributesToLLVM(Location loc, llvm::LLVMContext &ctx,
+ ArrayAttr arrayAttr, StringRef arrayAttrName) {
+ llvm::AttrBuilder attrBuilder(ctx);
+ if (!arrayAttr)
+ return attrBuilder;
+
+ for (Attribute attr : arrayAttr) {
+ if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+ FailureOr<llvm::Attribute> llvmAttr =
+ convertMLIRAttributeToLLVM(loc, ctx, stringAttr.getValue());
+ if (failed(llvmAttr))
+ return failure();
+ attrBuilder.addAttribute(*llvmAttr);
+ continue;
+ }
+
+ auto arrayAttr = dyn_cast<ArrayAttr>(attr);
+ if (!arrayAttr || arrayAttr.size() != 2)
+ return emitError(loc) << "expected '" << arrayAttrName
+ << "' to contain string or array attributes";
+
+ auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
+ auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
+ if (!keyAttr || !valueAttr)
+ return emitError(loc) << "expected arrays within '" << arrayAttrName
+ << "' to contain two strings";
+
+ FailureOr<llvm::Attribute> llvmAttr = convertMLIRAttributeToLLVM(
+ loc, ctx, keyAttr.getValue(), valueAttr.getValue());
+ if (failed(llvmAttr))
+ return failure();
+ attrBuilder.addAttribute(*llvmAttr);
+ }
+
+ return attrBuilder;
+}
+
LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
// Mapping from compile unit to its respective set of global variables.
DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars;
@@ -1191,6 +1268,15 @@ LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
}
}
}
+
+ // Forward the target-specific attributes to LLVM.
+ FailureOr<llvm::AttrBuilder> convertedTargetSpecificAttrs =
+ convertMLIRAttributesToLLVM(op.getLoc(), var->getContext(),
+ op.getTargetSpecificAttrsAttr(),
+ op.getTargetSpecificAttrsAttrName());
+ if (failed(convertedTargetSpecificAttrs))
+ return failure();
+ var->addAttributes(*convertedTargetSpecificAttrs);
}
// Create all llvm::GlobalAlias
@@ -1381,44 +1467,6 @@ LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
return success();
}
-/// Attempts to add an attribute identified by `key`, optionally with the given
-/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
-/// attribute has a kind known to LLVM IR, create the attribute of this kind,
-/// otherwise keep it as a string attribute. Performs additional checks for
-/// attributes known to have or not have a value in order to avoid assertions
-/// inside LLVM upon construction.
-static LogicalResult checkedAddLLVMFnAttribute(Location loc,
- llvm::Function *llvmFunc,
- StringRef key,
- StringRef value = StringRef()) {
- auto kind = llvm::Attribute::getAttrKindFromName(key);
- if (kind == llvm::Attribute::None) {
- llvmFunc->addFnAttr(key, value);
- return success();
- }
-
- if (llvm::Attribute::isIntAttrKind(kind)) {
- if (value.empty())
- return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
-
- int64_t result;
- if (!value.getAsInteger(/*Radix=*/0, result))
- llvmFunc->addFnAttr(
- llvm::Attribute::get(llvmFunc->getContext(), kind, result));
- else
- llvmFunc->addFnAttr(key, value);
- return success();
- }
-
- if (!value.empty())
- return emitError(loc) << "LLVM attribute '" << key
- << "' does not expect a value, found '" << value
- << "'";
-
- llvmFunc->addFnAttr(kind);
- return success();
-}
-
/// Return a representation of `value` as metadata.
static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
const llvm::APInt &value) {
@@ -1454,45 +1502,6 @@ static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
return llvm::MDNode::get(context, mdValues);
}
-/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
-/// Reports error to `loc` if any and returns immediately. Expects `attributes`
-/// to be an array attribute containing either string attributes, treated as
-/// value-less LLVM attributes, or array attributes containing two string
-/// attributes, with the first string being the name of the corresponding LLVM
-/// attribute and the second string beings its value. Note that even integer
-/// attributes are expected to have their values expressed as strings.
-static LogicalResult
-forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
- llvm::Function *llvmFunc) {
- if (!attributes)
- return success();
-
- for (Attribute attr : *attributes) {
- if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
- if (failed(
- checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
- return failure();
- continue;
- }
-
- auto arrayAttr = dyn_cast<ArrayAttr>(attr);
- if (!arrayAttr || arrayAttr.size() != 2)
- return emitError(loc)
- << "expected 'passthrough' to contain string or array attributes";
-
- auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
- auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
- if (!keyAttr || !valueAttr)
- return emitError(loc)
- << "expected arrays within 'passthrough' to contain two strings";
-
- if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
- valueAttr.getValue())))
- return failure();
- }
- return success();
-}
-
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Clear the block, branch value mappings, they are only relevant within one
// function.
@@ -1561,10 +1570,6 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
if (auto noNansFpMath = func.getNoNansFpMath())
llvmFunc->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath));
- if (auto approxFuncFpMath = func.getApproxFuncFpMath())
- llvmFunc->addFnAttr("approx-func-fp-math",
- llvm::toStringRef(*approxFuncFpMath));
-
if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
llvmFunc->addFnAttr("no-signed-zeros-fp-math",
llvm::toStringRef(*noSignedZerosFpMath));
@@ -1864,9 +1869,13 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
}
// Forward the pass-through attributes to LLVM.
- if (failed(forwardPassthroughAttributes(
- function.getLoc(), function.getPassthrough(), llvmFunc)))
+ FailureOr<llvm::AttrBuilder> convertedPassthroughAttrs =
+ convertMLIRAttributesToLLVM(function.getLoc(), llvmFunc->getContext(),
+ function.getPassthroughAttr(),
+ function.getPassthroughAttrName());
+ if (failed(convertedPassthroughAttrs))
return failure();
+ llvmFunc->addFnAttrs(*convertedPassthroughAttrs);
// Convert visibility attribute.
llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
@@ -2407,11 +2416,6 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
if (failed(translator.convertUnresolvedBlockAddress()))
return nullptr;
- // Once we've finished constructing elements in the module, we should convert
- // it to use the debug info format desired by LLVM.
- // See https://llvm.org/docs/RemoveDIsDebugInfo.html
- translator.llvmModule->convertToNewDbgValues();
-
// Add the necessary debug info module flags, if they were not encoded in MLIR
// beforehand.
translator.debugTranslation->addModuleFlagsIfNotPresent();
diff --git a/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..044da1c
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_dialect_library(MLIRTargetLLVMIRTransforms
+ TargetToDataLayout.cpp
+ TargetToTargetFeatures.cpp
+ TargetUtils.cpp
+
+ DEPENDS
+ MLIRTargetLLVMIRTransformsIncGen
+
+ LINK_COMPONENTS
+ MC
+ Target
+ TargetParser
+ AllTargetsAsmParsers
+ AllTargetsCodeGens
+ AllTargetsDescs
+ AllTargetsInfos
+
+ LINK_LIBS PUBLIC
+ MLIRDLTIDialect
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRTargetLLVMIRImport
+ )
diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp
new file mode 100644
index 0000000..c0f9ceb
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp
@@ -0,0 +1,62 @@
+//===- TargetToDataLayout.cpp - extract data layout from TargetMachine ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
+#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h"
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Target/LLVMIR/Import.h"
+
+namespace mlir {
+namespace LLVM {
+#define GEN_PASS_DEF_LLVMTARGETTODATALAYOUT
+#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc"
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+struct TargetToDataLayoutPass
+ : public LLVM::impl::LLVMTargetToDataLayoutBase<TargetToDataLayoutPass> {
+ using LLVM::impl::LLVMTargetToDataLayoutBase<
+ TargetToDataLayoutPass>::LLVMTargetToDataLayoutBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ if (initializeLLVMTargets)
+ LLVM::detail::initializeBackendsOnce();
+
+ auto targetAttr = op->getAttrOfType<LLVM::TargetAttrInterface>(
+ LLVM::LLVMDialect::getTargetAttrName());
+ if (!targetAttr) {
+ op->emitError()
+ << "no TargetAttrInterface-implementing attribute at key \""
+ << LLVM::LLVMDialect::getTargetAttrName() << "\"";
+ return signalPassFailure();
+ }
+
+ FailureOr<llvm::DataLayout> dataLayout =
+ LLVM::detail::getDataLayout(targetAttr);
+ if (failed(dataLayout)) {
+ op->emitError() << "failed to obtain llvm::DataLayout for " << targetAttr;
+ return signalPassFailure();
+ }
+
+ DataLayoutSpecInterface dataLayoutSpec =
+ mlir::translateDataLayout(dataLayout.value(), &getContext());
+
+ if (auto existingDlSpec = op->getAttrOfType<DataLayoutSpecInterface>(
+ DLTIDialect::kDataLayoutAttrName)) {
+ dataLayoutSpec = existingDlSpec.combineWith({dataLayoutSpec});
+ }
+
+ op->setAttr(DLTIDialect::kDataLayoutAttrName, dataLayoutSpec);
+ }
+};
diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp
new file mode 100644
index 0000000..4a1ca46
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp
@@ -0,0 +1,78 @@
+//===- TargetToTargetFeatures.cpp - extract features from TargetMachine ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
+#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h"
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Target/LLVMIR/Import.h"
+
+#include "llvm/MC/MCSubtargetInfo.h"
+
+namespace mlir {
+namespace LLVM {
+#define GEN_PASS_DEF_LLVMTARGETTOTARGETFEATURES
+#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc"
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+struct TargetToTargetFeaturesPass
+ : public LLVM::impl::LLVMTargetToTargetFeaturesBase<
+ TargetToTargetFeaturesPass> {
+ using LLVM::impl::LLVMTargetToTargetFeaturesBase<
+ TargetToTargetFeaturesPass>::LLVMTargetToTargetFeaturesBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ if (initializeLLVMTargets)
+ LLVM::detail::initializeBackendsOnce();
+
+ auto targetAttr = op->getAttrOfType<LLVM::TargetAttr>(
+ LLVM::LLVMDialect::getTargetAttrName());
+ if (!targetAttr) {
+ op->emitError() << "no LLVM::TargetAttr attribute at key \""
+ << LLVM::LLVMDialect::getTargetAttrName() << "\"";
+ return signalPassFailure();
+ }
+
+ FailureOr<std::unique_ptr<llvm::TargetMachine>> targetMachine =
+ LLVM::detail::getTargetMachine(targetAttr);
+ if (failed(targetMachine)) {
+ op->emitError() << "failed to obtain llvm::TargetMachine for "
+ << targetAttr;
+ return signalPassFailure();
+ }
+
+ llvm::MCSubtargetInfo const *subTargetInfo =
+ (*targetMachine)->getMCSubtargetInfo();
+
+ const std::vector<llvm::SubtargetFeatureKV> enabledFeatures =
+ subTargetInfo->getEnabledProcessorFeatures();
+
+ auto plussedFeatures = llvm::to_vector(
+ llvm::map_range(enabledFeatures, [](llvm::SubtargetFeatureKV feature) {
+ return std::string("+") + feature.Key;
+ }));
+
+ auto plussedFeaturesRefs = llvm::to_vector(llvm::map_range(
+ plussedFeatures, [](auto &it) { return StringRef(it.c_str()); }));
+
+ auto fullTargetFeaturesAttr =
+ LLVM::TargetFeaturesAttr::get(&getContext(), plussedFeaturesRefs);
+
+ auto updatedTargetAttr =
+ LLVM::TargetAttr::get(&getContext(), targetAttr.getTriple(),
+ targetAttr.getChip(), fullTargetFeaturesAttr);
+
+ op->setAttr(LLVM::LLVMDialect::getTargetAttrName(), updatedTargetAttr);
+ }
+};
diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp
new file mode 100644
index 0000000..f1d3622
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp
@@ -0,0 +1,71 @@
+//===- TargetUtils.cpp - utils for obtaining generic target backend info --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Transforms/Passes.h"
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Target/LLVMIR/Import.h"
+
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+
+#define DEBUG_TYPE "mlir-llvm-target-utils"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+void initializeBackendsOnce() {
+ static const auto initOnce = [] {
+ // Ensure that the targets, that LLVM has been configured to support,
+ // are loaded into the TargetRegistry.
+ llvm::InitializeAllTargets();
+ llvm::InitializeAllTargetMCs();
+ return true;
+ }();
+ (void)initOnce; // Dummy usage.
+}
+
+FailureOr<std::unique_ptr<llvm::TargetMachine>>
+getTargetMachine(mlir::LLVM::TargetAttrInterface attr) {
+ StringRef triple = attr.getTriple();
+ StringRef chipAKAcpu = attr.getChip();
+ // NB: `TargetAttrInterface::getFeatures()` is coarsely typed to work around
+ // cyclic dependency issue in tablegen files.
+ auto featuresAttr =
+ llvm::cast_if_present<LLVM::TargetFeaturesAttr>(attr.getFeatures());
+ std::string features = featuresAttr ? featuresAttr.getFeaturesString() : "";
+
+ std::string error;
+ const llvm::Target *target =
+ llvm::TargetRegistry::lookupTarget(triple, error);
+ if (!target || !error.empty()) {
+ LDBG() << "Looking up target '" << triple << "' failed: " << error << "\n";
+ return failure();
+ }
+
+ return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
+ llvm::Triple(triple), chipAKAcpu, features, {}, {}));
+}
+
+FailureOr<llvm::DataLayout>
+getDataLayout(mlir::LLVM::TargetAttrInterface attr) {
+ FailureOr<std::unique_ptr<llvm::TargetMachine>> targetMachine =
+ getTargetMachine(attr);
+ if (failed(targetMachine)) {
+ LDBG() << "Failed to retrieve the target machine for data layout.\n";
+ return failure();
+ }
+ return (targetMachine.value())->createDataLayout();
+}
+
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index e4ba478..ddd5946 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -71,7 +72,7 @@ public:
})
.Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
LLVM::LLVMPointerType, LLVM::LLVMStructType, VectorType,
- LLVM::LLVMTargetExtType>(
+ LLVM::LLVMTargetExtType, PtrLikeTypeInterface>(
[this](auto type) { return this->translate(type); })
.Default([](Type t) -> llvm::Type * {
llvm_unreachable("unknown LLVM dialect type");
@@ -149,6 +150,14 @@ private:
type.getIntParams());
}
+ /// Translates the given ptr type.
+ llvm::Type *translate(PtrLikeTypeInterface type) {
+ auto memSpace = dyn_cast<LLVM::AddressSpaceAttr>(type.getMemorySpace());
+ assert(memSpace && "expected pointer with the LLVM address space");
+ assert(!type.hasPtrMetadata() && "expected pointer without metadata");
+ return llvm::PointerType::get(context, memSpace.getAddressSpace());
+ }
+
/// Translates a list of types.
void translateTypes(ArrayRef<Type> types,
SmallVectorImpl<llvm::Type *> &result) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e86..3625dd2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -229,7 +229,7 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
}
template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
-LogicalResult deserializeCacheControlDecoration(
+static LogicalResult deserializeCacheControlDecoration(
Location loc, OpBuilder &opBuilder,
DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+ if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+ SmallVector<Attribute> flattenedElems;
+ for (Attribute element : elements) {
+ if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+ for (auto value : denseElemAttr.getValues<Attribute>())
+ flattenedElems.push_back(value);
+ } else {
+ flattenedElems.push_back(element);
+ }
+ }
+ auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+ constantMap.try_emplace(resultID, attr, tensorType);
+ } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
auto attr = DenseElementsAttr::get(shapedType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574..7fc7795 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
@@ -112,7 +113,9 @@ LogicalResult Serializer::serialize() {
// TODO: handle the other sections
processCapability();
- processExtension();
+ if (failed(processExtension())) {
+ return failure();
+ }
processMemoryModel();
processDebugInfo();
@@ -204,13 +207,24 @@ void Serializer::processDebugInfo() {
// TODO: Encode more debug instructions.
}
-void Serializer::processExtension() {
+LogicalResult Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName;
- for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
+ llvm::SmallSet<Extension, 4> deducedExts(
+ llvm::from_range, module.getVceTriple()->getExtensions());
+ auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
+ if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
+ TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
+ if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
+ return module.emitError(
+ "SPV_KHR_non_semantic_info extension not available");
+ deducedExts.insert(nonSemanticInfoExt);
+ }
+ for (spirv::Extension ext : deducedExts) {
extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
+ return success();
}
void Serializer::processMemoryModel() {
@@ -956,6 +970,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+ if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+ ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+ if (!innerShape.empty())
+ elementType = spirv::TensorArmType::get(innerShape, elementType);
+ }
// "If the Result Type is a cooperative matrix type, then there must be only
// one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +998,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
- } else if (isa<spirv::TensorArmType>(constType)) {
- if (isZeroValue(valueAttr)) {
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
- {typeID, resultID});
- return resultID;
- }
- numberOfConstituents = shapedType.getNumElements();
- operands.reserve(numberOfConstituents + 2);
- for (int i = 0; i < numberOfConstituents; ++i) {
- uint32_t elementID = 0;
- if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
- elementID =
- elementType.isInteger(1)
- ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
- : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
- }
- if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
- elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
- }
- if (!elementID) {
- return 0;
- }
- operands.push_back(elementID);
- }
+ } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 7047869..fb2cecd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -102,7 +102,7 @@ private:
void processDebugInfo();
- void processExtension();
+ LogicalResult processExtension();
void processMemoryModel();
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index ac338d55..796354e 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -21,8 +21,11 @@
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
@@ -76,24 +79,66 @@ void registerFromSPIRVTranslation() {
// Serialization registration
//===----------------------------------------------------------------------===//
-static LogicalResult serializeModule(spirv::ModuleOp module,
- raw_ostream &output) {
+static LogicalResult
+serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output,
+ const spirv::SerializationOptions &options) {
SmallVector<uint32_t, 0> binary;
- if (failed(spirv::serialize(module, binary)))
+ if (failed(spirv::serialize(moduleOp, binary)))
return failure();
- output.write(reinterpret_cast<char *>(binary.data()),
- binary.size() * sizeof(uint32_t));
+ size_t sizeInBytes = binary.size() * sizeof(uint32_t);
+
+ output.write(reinterpret_cast<char *>(binary.data()), sizeInBytes);
+
+ if (options.saveModuleForValidation) {
+ size_t dirSeparator =
+ options.validationFilePrefix.find(llvm::sys::path::get_separator());
+ // If file prefix includes directory check if that directory exists.
+ if (dirSeparator != std::string::npos) {
+ llvm::StringRef parentDir =
+ llvm::sys::path::parent_path(options.validationFilePrefix);
+ if (!llvm::sys::fs::is_directory(parentDir))
+ return moduleOp.emitError(
+ "validation prefix directory does not exist\n");
+ }
+
+ SmallString<128> filename;
+ int fd = 0;
+
+ std::error_code errorCode = llvm::sys::fs::createUniqueFile(
+ options.validationFilePrefix + "%%%%%%.spv", fd, filename);
+ if (errorCode)
+ return moduleOp.emitError("error creating validation output file: ")
+ << errorCode.message() << "\n";
+
+ llvm::raw_fd_ostream validationOutput(fd, /*shouldClose=*/true);
+ validationOutput.write(reinterpret_cast<char *>(binary.data()),
+ sizeInBytes);
+ validationOutput.flush();
+ }
return mlir::success();
}
namespace mlir {
void registerToSPIRVTranslation() {
+ static llvm::cl::opt<std::string> validationFilesPrefix(
+ "spirv-save-validation-files-with-prefix",
+ llvm::cl::desc(
+ "When non-empty string is passed each serialized SPIR-V module is "
+ "saved to an additional file that starts with the given prefix. This "
+ "is used to generate separate binaries for validation, where "
+ "`--split-input-file` normally combines all outputs into one. The "
+ "one combined output (`-o`) is still written. Created files need to "
+ "be removed manually once processed."),
+ llvm::cl::init(""));
+
TranslateFromMLIRRegistration toBinary(
"serialize-spirv", "serialize SPIR-V dialect",
- [](spirv::ModuleOp module, raw_ostream &output) {
- return serializeModule(module, output);
+ [](spirv::ModuleOp moduleOp, raw_ostream &output) {
+ return serializeModule(moduleOp, output,
+ {true, false, !validationFilesPrefix.empty(),
+ validationFilesPrefix});
},
[](DialectRegistry &registry) {
registry.insert<spirv::SPIRVDialect>();
diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt
new file mode 100644
index 0000000..890fc0ec
--- /dev/null
+++ b/mlir/lib/Target/Wasm/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIRTargetWasmImport
+ TranslateRegistration.cpp
+ TranslateFromWasm.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm
+
+ LINK_LIBS PUBLIC
+ MLIRWasmSSADialect
+ MLIRIR
+ MLIRSupport
+ MLIRTranslateLib
+)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
new file mode 100644
index 0000000..6afbe05
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -0,0 +1,1522 @@
+//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the WebAssembly importer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Target/Wasm/WasmBinaryEncoding.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LEB128.h"
+#include "llvm/Support/LogicalResult.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <variant>
+
+#define DEBUG_TYPE "wasm-translate"
+
+static_assert(CHAR_BIT == 8,
+ "This code expects std::byte to be exactly 8 bits");
+
+using namespace mlir;
+using namespace mlir::wasm;
+using namespace mlir::wasmssa;
+
+namespace {
+using section_id_t = uint8_t;
+enum struct WasmSectionType : section_id_t {
+ CUSTOM = 0,
+ TYPE = 1,
+ IMPORT = 2,
+ FUNCTION = 3,
+ TABLE = 4,
+ MEMORY = 5,
+ GLOBAL = 6,
+ EXPORT = 7,
+ START = 8,
+ ELEMENT = 9,
+ CODE = 10,
+ DATA = 11,
+ DATACOUNT = 12
+};
+
+constexpr section_id_t highestWasmSectionID{
+ static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+
+#define APPLY_WASM_SEC_TRANSFORM \
+ WASM_SEC_TRANSFORM(CUSTOM) \
+ WASM_SEC_TRANSFORM(TYPE) \
+ WASM_SEC_TRANSFORM(IMPORT) \
+ WASM_SEC_TRANSFORM(FUNCTION) \
+ WASM_SEC_TRANSFORM(TABLE) \
+ WASM_SEC_TRANSFORM(MEMORY) \
+ WASM_SEC_TRANSFORM(GLOBAL) \
+ WASM_SEC_TRANSFORM(EXPORT) \
+ WASM_SEC_TRANSFORM(START) \
+ WASM_SEC_TRANSFORM(ELEMENT) \
+ WASM_SEC_TRANSFORM(CODE) \
+ WASM_SEC_TRANSFORM(DATA) \
+ WASM_SEC_TRANSFORM(DATACOUNT)
+
+template <WasmSectionType>
+constexpr const char *wasmSectionName = "";
+
+#define WASM_SEC_TRANSFORM(section) \
+ template <> \
+ [[maybe_unused]] constexpr const char \
+ *wasmSectionName<WasmSectionType::section> = #section;
+APPLY_WASM_SEC_TRANSFORM
+#undef WASM_SEC_TRANSFORM
+
+constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
+ return secType != WasmSectionType::CUSTOM;
+}
+
+template <std::byte... Bytes>
+struct ByteSequence {};
+
+/// Template class for representing a byte sequence of only one byte
+template <std::byte Byte>
+struct UniqueByte : ByteSequence<Byte> {};
+
+[[maybe_unused]] constexpr ByteSequence<
+ WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64,
+ WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64,
+ WasmBinaryEncoding::Type::v128> valueTypesEncodings{};
+
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+ ByteSequence<allowedFlags...> = {}) {
+ return ((value == allowedFlags) | ... | false);
+}
+
+template <std::byte... flags>
+constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
+ return !isValueOneOf<flags...>(value);
+}
+
+struct GlobalTypeRecord {
+ Type type;
+ bool isMutable;
+};
+
+struct TypeIdxRecord {
+ size_t id;
+};
+
+struct SymbolRefContainer {
+ FlatSymbolRefAttr symbol;
+};
+
+struct GlobalSymbolRefContainer : SymbolRefContainer {
+ Type globalType;
+};
+
+struct FunctionSymbolRefContainer : SymbolRefContainer {
+ FunctionType functionType;
+};
+
+using ImportDesc =
+ std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+
+using parsed_inst_t = FailureOr<SmallVector<Value>>;
+
+struct WasmModuleSymbolTables {
+ SmallVector<FunctionSymbolRefContainer> funcSymbols;
+ SmallVector<GlobalSymbolRefContainer> globalSymbols;
+ SmallVector<SymbolRefContainer> memSymbols;
+ SmallVector<SymbolRefContainer> tableSymbols;
+ SmallVector<FunctionType> moduleFuncTypes;
+
+ std::string getNewSymbolName(StringRef prefix, size_t id) const {
+ return (prefix + Twine{id}).str();
+ }
+
+ std::string getNewFuncSymbolName() const {
+ size_t id = funcSymbols.size();
+ return getNewSymbolName("func_", id);
+ }
+
+ std::string getNewGlobalSymbolName() const {
+ size_t id = globalSymbols.size();
+ return getNewSymbolName("global_", id);
+ }
+
+ std::string getNewMemorySymbolName() const {
+ size_t id = memSymbols.size();
+ return getNewSymbolName("mem_", id);
+ }
+
+ std::string getNewTableSymbolName() const {
+ size_t id = tableSymbols.size();
+ return getNewSymbolName("table_", id);
+ }
+};
+
+class ParserHead;
+
+/// Wrapper around SmallVector to only allow access as push and pop on the
+/// stack. Makes sure that there are no "free accesses" on the stack to preserve
+/// its state.
+class ValueStack {
+private:
+ struct LabelLevel {
+ size_t stackIdx;
+ LabelLevelOpInterface levelOp;
+ };
+
+public:
+ bool empty() const { return values.empty(); }
+
+ size_t size() const { return values.size(); }
+
+ /// Pops values from the stack because they are being used in an operation.
+ /// @param operandTypes The list of expected types of the operation, used
+ /// to know how many values to pop and check if the types match the
+ /// expectation.
+ /// @param opLoc Location of the caller, used to report accurately the
+ /// location
+ /// if an error occurs.
+ /// @return Failure or the vector of popped values.
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
+ Location *opLoc);
+
+ /// Push the results of an operation to the stack so they can be used in a
+ /// following operation.
+ /// @param results The list of results of the operation
+ /// @param opLoc Location of the caller, used to report accurately the
+ /// location
+ /// if an error occurs.
+ LogicalResult pushResults(ValueRange results, Location *opLoc);
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// A simple dump function for debugging.
+ /// Writes output to llvm::dbgs().
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+
+private:
+ SmallVector<Value> values;
+};
+
+using local_val_t = TypedValue<wasmssa::LocalRefType>;
+
+class ExpressionParser {
+public:
+ using locals_t = SmallVector<local_val_t>;
+ ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
+ ArrayRef<local_val_t> initLocal)
+ : parser{parser}, symbols{symbols}, locals{initLocal} {}
+
+private:
+ template <std::byte opCode>
+ inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
+
+ template <typename valueT>
+ parsed_inst_t
+ parseConstInst(OpBuilder &builder,
+ std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
+
+ /// Construct an operation with \p numOperands operands and a single result.
+ /// Each operand must have the same type. Suitable for e.g. binops, unary
+ /// ops, etc.
+ ///
+ /// \p opcode - The WASM opcode to build.
+ /// \p valueType - The operand and result type for the built instruction.
+ /// \p numOperands - The number of operands for the built operation.
+ ///
+ /// \returns The parsed instruction result, or failure.
+ template <typename opcode, typename valueType, unsigned int numOperands>
+ inline parsed_inst_t
+ buildNumericOp(OpBuilder &builder,
+ std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
+
+ /// This function generates a dispatch tree to associate an opcode with a
+ /// parser. Parsers are registered by specialising the
+ /// `parseSpecificInstruction` function for the op code to handle.
+ ///
+ /// The dispatcher is generated by recursively creating all possible patterns
+ /// for an opcode and calling the relevant parser on the leaf.
+ ///
+ /// @tparam patternBitSize is the first bit for which the pattern is not fixed
+ ///
+ /// @tparam highBitPattern is the fixed pattern that this instance handles for
+ /// the 8-patternBitSize bits
+ template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
+ inline parsed_inst_t dispatchToInstParser(std::byte opCode,
+ OpBuilder &builder) {
+ static_assert(patternBitSize <= 8,
+ "PatternBitSize is outside of range of opcode space! "
+ "(expected at most 8 bits)");
+ if constexpr (patternBitSize < 8) {
+ constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
+ constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
+ constexpr size_t nextPatternBitSize = patternBitSize + 1;
+ if ((opCode & bitSelect) != std::byte{0})
+ return dispatchToInstParser<nextPatternBitSize,
+ nextHighBitPatternStem | std::byte{1}>(
+ opCode, builder);
+ return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
+ opCode, builder);
+ } else {
+ return parseSpecificInstruction<highBitPattern>(builder);
+ }
+ }
+
+ struct ParseResultWithInfo {
+ SmallVector<Value> opResults;
+ std::byte endingByte;
+ };
+
+public:
+ template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+ parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
+
+ template <std::byte... ExpressionParseEnd>
+ FailureOr<ParseResultWithInfo>
+ parse(OpBuilder &builder,
+ ByteSequence<ExpressionParseEnd...> parsingEndFilters);
+
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
+ return valueStack.popOperands(operandTypes, &currentOpLoc.value());
+ }
+
+ LogicalResult pushResults(ValueRange results) {
+ return valueStack.pushResults(results, &currentOpLoc.value());
+ }
+
+ /// The local.set and local.tee operations behave similarly and only differ
+ /// on their return value. This function factorizes the behavior of the two
+ /// operations in one place.
+ template <typename OpToCreate>
+ parsed_inst_t parseSetOrTee(OpBuilder &);
+
+private:
+ std::optional<Location> currentOpLoc;
+ ParserHead &parser;
+ WasmModuleSymbolTables const &symbols;
+ locals_t locals;
+ ValueStack valueStack;
+};
+
+class ParserHead {
+public:
+ ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
+ ParserHead(ParserHead &&) = default;
+
+private:
+ ParserHead(ParserHead const &other) = default;
+
+public:
+ auto getLocation() const {
+ return FileLineColLoc::get(locName, 0, anchorOffset + offset);
+ }
+
+ FailureOr<StringRef> consumeNBytes(size_t nBytes) {
+ LDBG() << "Consume " << nBytes << " bytes";
+ LDBG() << " Bytes remaining: " << size();
+ LDBG() << " Current offset: " << offset;
+ if (nBytes > size())
+ return emitError(getLocation(), "trying to extract ")
+ << nBytes << "bytes when only " << size() << "are available";
+
+ StringRef res = head.slice(offset, offset + nBytes);
+ offset += nBytes;
+ LDBG() << " Updated offset (+" << nBytes << "): " << offset;
+ return res;
+ }
+
+ FailureOr<std::byte> consumeByte() {
+ FailureOr<StringRef> res = consumeNBytes(1);
+ if (failed(res))
+ return failure();
+ return std::byte{*res->bytes_begin()};
+ }
+
+ template <typename T>
+ FailureOr<T> parseLiteral();
+
+ FailureOr<uint32_t> parseVectorSize();
+
+private:
+ // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
+ // if parseLiteral specialization were moved here, but default GCC on Ubuntu
+ // 22.04 has bug with template specialization in class declaration
+ inline FailureOr<uint32_t> parseUI32();
+ inline FailureOr<int64_t> parseI64();
+
+public:
+ FailureOr<StringRef> parseName() {
+ FailureOr<uint32_t> size = parseVectorSize();
+ if (failed(size))
+ return failure();
+
+ return consumeNBytes(*size);
+ }
+
+ FailureOr<WasmSectionType> parseWasmSectionType() {
+ FailureOr<std::byte> id = consumeByte();
+ if (failed(id))
+ return failure();
+ if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
+ return emitError(getLocation(), "invalid section ID: ")
+ << static_cast<int>(*id);
+ return static_cast<WasmSectionType>(*id);
+ }
+
+ FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
+ using WasmLimits = WasmBinaryEncoding::LimitHeader;
+ FileLineColLoc limitLocation = getLocation();
+ FailureOr<std::byte> limitHeader = consumeByte();
+ if (failed(limitHeader))
+ return failure();
+
+ if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
+ return emitError(limitLocation, "invalid limit header: ")
+ << static_cast<int>(*limitHeader);
+ FailureOr<uint32_t> minParse = parseUI32();
+ if (failed(minParse))
+ return failure();
+ std::optional<uint32_t> max{std::nullopt};
+ if (*limitHeader == WasmLimits::bothLimits) {
+ FailureOr<uint32_t> maxParse = parseUI32();
+ if (failed(maxParse))
+ return failure();
+ max = *maxParse;
+ }
+ return LimitType::get(ctx, *minParse, max);
+ }
+
+ FailureOr<Type> parseValueType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> typeEncoding = consumeByte();
+ if (failed(typeEncoding))
+ return failure();
+ switch (*typeEncoding) {
+ case WasmBinaryEncoding::Type::i32:
+ return IntegerType::get(ctx, 32);
+ case WasmBinaryEncoding::Type::i64:
+ return IntegerType::get(ctx, 64);
+ case WasmBinaryEncoding::Type::f32:
+ return Float32Type::get(ctx);
+ case WasmBinaryEncoding::Type::f64:
+ return Float64Type::get(ctx);
+ case WasmBinaryEncoding::Type::v128:
+ return IntegerType::get(ctx, 128);
+ case WasmBinaryEncoding::Type::funcRef:
+ return wasmssa::FuncRefType::get(ctx);
+ case WasmBinaryEncoding::Type::externRef:
+ return wasmssa::ExternRefType::get(ctx);
+ default:
+ return emitError(typeLoc, "invalid value type encoding: ")
+ << static_cast<int>(*typeEncoding);
+ }
+ }
+
+ FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
+ using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
+ FailureOr<Type> typeParsed = parseValueType(ctx);
+ if (failed(typeParsed))
+ return failure();
+ FileLineColLoc mutLoc = getLocation();
+ FailureOr<std::byte> mutSpec = consumeByte();
+ if (failed(mutSpec))
+ return failure();
+ if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
+ return emitError(mutLoc, "invalid global mutability specifier: ")
+ << static_cast<int>(*mutSpec);
+ return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
+ }
+
+ FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
+ FailureOr<uint32_t> nParamsParsed = parseVectorSize();
+ if (failed(nParamsParsed))
+ return failure();
+ uint32_t nParams = *nParamsParsed;
+ SmallVector<Type> res{};
+ res.reserve(nParams);
+ for (size_t i = 0; i < nParams; ++i) {
+ FailureOr<Type> parsedType = parseValueType(ctx);
+ if (failed(parsedType))
+ return failure();
+ res.push_back(*parsedType);
+ }
+ return TupleType::get(ctx, res);
+ }
+
+ FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> funcTypeHeader = consumeByte();
+ if (failed(funcTypeHeader))
+ return failure();
+ if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
+ return emitError(typeLoc, "invalid function type header byte. Expecting ")
+ << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
+ << " got " << std::to_integer<unsigned>(*funcTypeHeader);
+ FailureOr<TupleType> inputTypes = parseResultType(ctx);
+ if (failed(inputTypes))
+ return failure();
+
+ FailureOr<TupleType> resTypes = parseResultType(ctx);
+ if (failed(resTypes))
+ return failure();
+
+ return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
+ }
+
+ FailureOr<TypeIdxRecord> parseTypeIndex() {
+ FailureOr<uint32_t> res = parseUI32();
+ if (failed(res))
+ return failure();
+ return TypeIdxRecord{*res};
+ }
+
+ FailureOr<TableType> parseTableType(MLIRContext *ctx) {
+ FailureOr<Type> elmTypeParse = parseValueType(ctx);
+ if (failed(elmTypeParse))
+ return failure();
+ if (!isWasmRefType(*elmTypeParse))
+ return emitError(getLocation(), "invalid element type for table");
+ FailureOr<LimitType> limitParse = parseLimit(ctx);
+ if (failed(limitParse))
+ return failure();
+ return TableType::get(ctx, *elmTypeParse, *limitParse);
+ }
+
+ FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
+ FileLineColLoc importLoc = getLocation();
+ FailureOr<std::byte> importType = consumeByte();
+ auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
+ if (failed(parseResult))
+ return failure();
+ return {*parseResult};
+ };
+ if (failed(importType))
+ return failure();
+ switch (*importType) {
+ case WasmBinaryEncoding::Import::typeID:
+ return packager(parseTypeIndex());
+ case WasmBinaryEncoding::Import::tableType:
+ return packager(parseTableType(ctx));
+ case WasmBinaryEncoding::Import::memType:
+ return packager(parseLimit(ctx));
+ case WasmBinaryEncoding::Import::globalType:
+ return packager(parseGlobalType(ctx));
+ default:
+ return emitError(importLoc, "invalid import type descriptor: ")
+ << static_cast<int>(*importType);
+ }
+ }
+
+ parsed_inst_t parseExpression(OpBuilder &builder,
+ WasmModuleSymbolTables const &symbols,
+ ArrayRef<local_val_t> locals = {}) {
+ auto eParser = ExpressionParser{*this, symbols, locals};
+ return eParser.parse(builder);
+ }
+
+ LogicalResult parseCodeFor(FuncOp func,
+ WasmModuleSymbolTables const &symbols) {
+ SmallVector<local_val_t> locals{};
+ // Populating locals with function argument
+ Block &block = func.getBody().front();
+ // Delete temporary return argument which was only created for IR validity
+ assert(func.getBody().getBlocks().size() == 1 &&
+ "Function should only have its default created block at this point");
+ assert(block.getOperations().size() == 1 &&
+ "Only the placeholder return op should be present at this point");
+ auto returnOp = cast<ReturnOp>(&block.back());
+ assert(returnOp);
+
+ FailureOr<uint32_t> codeSizeInBytes = parseUI32();
+ if (failed(codeSizeInBytes))
+ return failure();
+ FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes);
+ if (failed(codeContent))
+ return failure();
+ auto name = StringAttr::get(func->getContext(),
+ locName.str() + "::" + func.getSymName());
+ auto cParser = ParserHead{*codeContent, name};
+ FailureOr<uint32_t> localVecSize = cParser.parseVectorSize();
+ if (failed(localVecSize))
+ return failure();
+ OpBuilder builder{&func.getBody().front().back()};
+ for (auto arg : block.getArguments())
+ locals.push_back(cast<TypedValue<LocalRefType>>(arg));
+ // Declare the local ops
+ uint32_t nVarVec = *localVecSize;
+ for (size_t i = 0; i < nVarVec; ++i) {
+ FileLineColLoc varLoc = cParser.getLocation();
+ FailureOr<uint32_t> nSubVar = cParser.parseUI32();
+ if (failed(nSubVar))
+ return failure();
+ FailureOr<Type> varT = cParser.parseValueType(func->getContext());
+ if (failed(varT))
+ return failure();
+ for (size_t j = 0; j < *nSubVar; ++j) {
+ auto local = builder.create<LocalOp>(varLoc, *varT);
+ locals.push_back(local.getResult());
+ }
+ }
+ parsed_inst_t res = cParser.parseExpression(builder, symbols, locals);
+ if (failed(res))
+ return failure();
+ if (!cParser.end())
+ return emitError(cParser.getLocation(),
+ "unparsed garbage remaining at end of code block");
+ builder.create<ReturnOp>(func->getLoc(), *res);
+ returnOp->erase();
+ return success();
+ }
+
+ bool end() const { return curHead().empty(); }
+
+ ParserHead copy() const { return *this; }
+
+private:
+ StringRef curHead() const { return head.drop_front(offset); }
+
+ FailureOr<std::byte> peek() const {
+ if (end())
+ return emitError(
+ getLocation(),
+ "trying to peek at next byte, but input stream is empty");
+ return static_cast<std::byte>(curHead().front());
+ }
+
+ size_t size() const { return head.size() - offset; }
+
+ StringRef head;
+ StringAttr locName;
+ unsigned anchorOffset{0};
+ unsigned offset{0};
+};
+
+template <>
+FailureOr<float> ParserHead::parseLiteral<float>() {
+ FailureOr<StringRef> bytes = consumeNBytes(4);
+ if (failed(bytes))
+ return failure();
+ return llvm::support::endian::read<float>(bytes->bytes_begin(),
+ llvm::endianness::little);
+}
+
+template <>
+FailureOr<double> ParserHead::parseLiteral<double>() {
+ FailureOr<StringRef> bytes = consumeNBytes(8);
+ if (failed(bytes))
+ return failure();
+ return llvm::support::endian::read<double>(bytes->bytes_begin(),
+ llvm::endianness::little);
+}
+
+template <>
+FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
+ char const *error = nullptr;
+ uint32_t res{0};
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max()))
+ return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+ res = static_cast<uint32_t>(decoded);
+ offset += encodingSize;
+ return res;
+}
+
+template <>
+FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
+ char const *error = nullptr;
+ int32_t res{0};
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+ if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
+ std::isgreater(std::numeric_limits<int32_t>::min(), decoded))
+ return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+ res = static_cast<int32_t>(decoded);
+ offset += encodingSize;
+ return res;
+}
+
+template <>
+FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
+ char const *error = nullptr;
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ offset += encodingSize;
+ return res;
+}
+
+FailureOr<uint32_t> ParserHead::parseVectorSize() {
+ return parseLiteral<uint32_t>();
+}
+
+inline FailureOr<uint32_t> ParserHead::parseUI32() {
+ return parseLiteral<uint32_t>();
+}
+
+inline FailureOr<int64_t> ParserHead::parseI64() {
+ return parseLiteral<int64_t>();
+}
+
+template <std::byte opCode>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
+ return emitError(*currentOpLoc, "unknown instruction opcode: ")
+ << static_cast<int>(opCode);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void ValueStack::dump() const {
+ llvm::dbgs() << "================= Wasm ValueStack =======================\n";
+ llvm::dbgs() << "size: " << size() << "\n";
+ llvm::dbgs() << "<Top>"
+ << "\n";
+ // Stack is pushed to via push_back. Therefore the top of the stack is the
+ // end of the vector. Iterate in reverse so that the first thing we print
+ // is the top of the stack.
+ size_t stackSize = size();
+ for (size_t idx = 0; idx < stackSize; idx++) {
+ size_t actualIdx = stackSize - 1 - idx;
+ llvm::dbgs() << " ";
+ values[actualIdx].dump();
+ }
+ llvm::dbgs() << "<Bottom>"
+ << "\n";
+ llvm::dbgs() << "=========================================================\n";
+}
+#endif
+
+parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
+ LDBG() << "Popping from ValueStack\n"
+ << " Elements(s) to pop: " << operandTypes.size() << "\n"
+ << " Current stack size: " << values.size();
+ if (operandTypes.size() > values.size())
+ return emitError(*opLoc,
+ "stack doesn't contain enough values. trying to get ")
+ << operandTypes.size() << " operands on a stack containing only "
+ << values.size() << " values.";
+ size_t stackIdxOffset = values.size() - operandTypes.size();
+ SmallVector<Value> res{};
+ res.reserve(operandTypes.size());
+ for (size_t i{0}; i < operandTypes.size(); ++i) {
+ Value operand = values[i + stackIdxOffset];
+ Type stackType = operand.getType();
+ if (stackType != operandTypes[i])
+ return emitError(*opLoc, "invalid operand type on stack. expecting ")
+ << operandTypes[i] << ", value on stack is of type " << stackType
+ << ".";
+ LDBG() << " POP: " << operand;
+ res.push_back(operand);
+ }
+ values.resize(values.size() - operandTypes.size());
+ LDBG() << " Updated stack size: " << values.size();
+ return res;
+}
+
+LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
+ LDBG() << "Pushing to ValueStack\n"
+ << " Elements(s) to push: " << results.size() << "\n"
+ << " Current stack size: " << values.size();
+ for (Value val : results) {
+ if (!isWasmValueType(val.getType()))
+ return emitError(*opLoc, "invalid value type on stack: ")
+ << val.getType();
+ LDBG() << " PUSH: " << val;
+ values.push_back(val);
+ }
+
+ LDBG() << " Updated stack size: " << values.size();
+ return success();
+}
+
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+ UniqueByte<EndParseByte> endByte) {
+ auto res = parse(builder, ByteSequence<EndParseByte>{});
+ if (failed(res))
+ return failure();
+ return res->opResults;
+}
+
+template <std::byte... ExpressionParseEnd>
+FailureOr<ExpressionParser::ParseResultWithInfo>
+ExpressionParser::parse(OpBuilder &builder,
+ ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
+ SmallVector<Value> res;
+ for (;;) {
+ currentOpLoc = parser.getLocation();
+ FailureOr<std::byte> opCode = parser.consumeByte();
+ if (failed(opCode))
+ return failure();
+ if (isValueOneOf(*opCode, parsingEndFilters))
+ return {{res, *opCode}};
+ parsed_inst_t resParsed;
+ resParsed = dispatchToInstParser(*opCode, builder);
+ if (failed(resParsed))
+ return failure();
+ std::swap(res, *resParsed);
+ if (failed(pushResults(res)))
+ return failure();
+ }
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
+ FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
+ Location instLoc = *currentOpLoc;
+ if (failed(id))
+ return failure();
+ if (*id >= locals.size())
+ return emitError(instLoc, "invalid local index. function has ")
+ << locals.size() << " accessible locals, received index " << *id;
+ return {{builder.create<LocalGetOp>(instLoc, locals[*id]).getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::globalGet>(OpBuilder &builder) {
+ FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
+ Location instLoc = *currentOpLoc;
+ if (failed(id))
+ return failure();
+ if (*id >= symbols.globalSymbols.size())
+ return emitError(instLoc, "invalid global index. function has ")
+ << symbols.globalSymbols.size()
+ << " accessible globals, received index " << *id;
+ GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
+ auto globalOp = builder.create<GlobalGetOp>(instLoc, globalVar.globalType,
+ globalVar.symbol);
+
+ return {{globalOp.getResult()}};
+}
+
+template <typename OpToCreate>
+parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
+ FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
+ if (failed(id))
+ return failure();
+ if (*id >= locals.size())
+ return emitError(*currentOpLoc, "invalid local index. function has ")
+ << locals.size() << " accessible locals, received index " << *id;
+ if (valueStack.empty())
+ return emitError(
+ *currentOpLoc,
+ "invalid stack access, trying to access a value on an empty stack.");
+
+ parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
+ if (failed(poppedOp))
+ return failure();
+ return {
+ builder.create<OpToCreate>(*currentOpLoc, locals[*id], poppedOp->front())
+ ->getResults()};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::localSet>(OpBuilder &builder) {
+ return parseSetOrTee<LocalSetOp>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::localTee>(OpBuilder &builder) {
+ return parseSetOrTee<LocalTeeOp>(builder);
+}
+
+template <typename T>
+inline Type buildLiteralType(OpBuilder &);
+
+template <>
+inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+[[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+[[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+inline Type buildLiteralType<float>(OpBuilder &builder) {
+ return builder.getF32Type();
+}
+
+template <>
+inline Type buildLiteralType<double>(OpBuilder &builder) {
+ return builder.getF64Type();
+}
+
+template <typename ValT,
+ typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+struct AttrHolder;
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
+ using type = IntegerAttr;
+};
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
+ using type = FloatAttr;
+};
+
+template <typename ValT>
+using attr_holder_t = typename AttrHolder<ValT>::type;
+
+template <typename ValT,
+ typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
+ return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
+}
+
+template <typename valueT>
+parsed_inst_t ExpressionParser::parseConstInst(
+ OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
+ auto parsedConstant = parser.parseLiteral<valueT>();
+ if (failed(parsedConstant))
+ return failure();
+ auto constOp =
+ ConstOp::create(builder, *currentOpLoc,
+ buildLiteralAttr<valueT>(builder, *parsedConstant));
+ return {{constOp.getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
+ return parseConstInst<int32_t>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
+ return parseConstInst<int64_t>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
+ return parseConstInst<float>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
+ return parseConstInst<double>(builder);
+}
+
+template <typename opcode, typename valueType, unsigned int numOperands>
+inline parsed_inst_t ExpressionParser::buildNumericOp(
+ OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) {
+ auto ty = buildLiteralType<valueType>(builder);
+ LDBG() << "*** buildNumericOp: numOperands = " << numOperands
+ << ", type = " << ty << " ***";
+ auto tysToPop = SmallVector<Type, numOperands>();
+ tysToPop.resize(numOperands);
+ std::fill(tysToPop.begin(), tysToPop.end(), ty);
+ auto operands = popOperands(tysToPop);
+ if (failed(operands))
+ return failure();
+ auto op = builder.create<opcode>(*currentOpLoc, *operands).getResult();
+ LDBG() << "Built operation: " << op;
+ return {{op}};
+}
+
+// Convenience macro for generating numerical operations.
+#define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) { \
+ return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \
+ }
+
+// Macro to define binops that only support integer types.
+#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t)
+
+// Macro to define binops that only support floating point types.
+#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float) \
+ BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double)
+
+// Macro to define binops that support both floating point and integer types.
+#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
+
+// Macro to implement unary ops that only support integers.
+#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I32, int32_t) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I64, int64_t)
+
+// Macro to implement unary ops that support integer and floating point types.
+#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F32, float) \
+ BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F64, double)
+
+BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
+BUILD_NUMERIC_BINOP_FP(DivOp, div)
+BUILD_NUMERIC_BINOP_FP(MaxOp, max)
+BUILD_NUMERIC_BINOP_FP(MinOp, min)
+BUILD_NUMERIC_BINOP_INT(AndOp, and)
+BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
+BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
+BUILD_NUMERIC_BINOP_INT(OrOp, or)
+BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
+BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
+BUILD_NUMERIC_BINOP_INT(RotlOp, rotl)
+BUILD_NUMERIC_BINOP_INT(RotrOp, rotr)
+BUILD_NUMERIC_BINOP_INT(ShLOp, shl)
+BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
+BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
+BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
+BUILD_NUMERIC_BINOP_INTFP(AddOp, add)
+BUILD_NUMERIC_BINOP_INTFP(MulOp, mul)
+BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
+BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
+BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil)
+BUILD_NUMERIC_UNARY_OP_FP(FloorOp, floor)
+BUILD_NUMERIC_UNARY_OP_FP(NegOp, neg)
+BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
+BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
+BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
+BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
+BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
+
+// Don't need these anymore so let's undef them.
+#undef BUILD_NUMERIC_BINOP_FP
+#undef BUILD_NUMERIC_BINOP_INT
+#undef BUILD_NUMERIC_BINOP_INTFP
+#undef BUILD_NUMERIC_UNARY_OP_FP
+#undef BUILD_NUMERIC_UNARY_OP_INT
+#undef BUILD_NUMERIC_OP
+#undef BUILD_NUMERIC_CAST_OP
+
+class WasmBinaryParser {
+private:
+ struct SectionRegistry {
+ using section_location_t = StringRef;
+
+ std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
+ registry;
+
+ template <WasmSectionType SecType>
+ std::conditional_t<sectionShouldBeUnique(SecType),
+ std::optional<section_location_t>,
+ ArrayRef<section_location_t>>
+ getContentForSection() const {
+ constexpr auto idx = static_cast<size_t>(SecType);
+ if constexpr (sectionShouldBeUnique(SecType)) {
+ return registry[idx].empty() ? std::nullopt
+ : std::make_optional(registry[idx][0]);
+ } else {
+ return registry[idx];
+ }
+ }
+
+ bool hasSection(WasmSectionType secType) const {
+ return !registry[static_cast<size_t>(secType)].empty();
+ }
+
+ ///
+ /// @returns success if registration valid, failure in case registration
+ /// can't be done (if another section of same type already exist and this
+ /// section type should only be present once)
+ ///
+ LogicalResult registerSection(WasmSectionType secType,
+ section_location_t location, Location loc) {
+ if (sectionShouldBeUnique(secType) && hasSection(secType))
+ return emitError(loc,
+ "trying to add a second instance of unique section");
+
+ registry[static_cast<size_t>(secType)].push_back(location);
+ emitRemark(loc, "Adding section with section ID ")
+ << static_cast<uint8_t>(secType);
+ return success();
+ }
+
+ LogicalResult populateFromBody(ParserHead ph) {
+ while (!ph.end()) {
+ FileLineColLoc sectionLoc = ph.getLocation();
+ FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
+ if (failed(secType))
+ return failure();
+
+ FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
+ if (failed(secSizeParsed))
+ return failure();
+
+ uint32_t secSize = *secSizeParsed;
+ FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
+ if (failed(sectionContent))
+ return failure();
+
+ LogicalResult registration =
+ registerSection(*secType, *sectionContent, sectionLoc);
+
+ if (failed(registration))
+ return failure();
+ }
+ return success();
+ }
+ };
+
+ auto getLocation(int offset = 0) const {
+ return FileLineColLoc::get(srcName, 0, offset);
+ }
+
+ template <WasmSectionType>
+ LogicalResult parseSectionItem(ParserHead &, size_t);
+
+ template <WasmSectionType section>
+ LogicalResult parseSection() {
+ auto secName = std::string{wasmSectionName<section>};
+ auto sectionNameAttr =
+ StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
+ unsigned offset = 0;
+ auto getLocation = [sectionNameAttr, &offset]() {
+ return FileLineColLoc::get(sectionNameAttr, 0, offset);
+ };
+ auto secContent = registry.getContentForSection<section>();
+ if (!secContent) {
+ LDBG() << secName << " section is not present in file.";
+ return success();
+ }
+
+ auto secSrc = secContent.value();
+ ParserHead ph{secSrc, sectionNameAttr};
+ FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
+ if (failed(nElemsParsed))
+ return failure();
+ uint32_t nElems = *nElemsParsed;
+ LDBG() << "starting to parse " << nElems << " items for section "
+ << secName;
+ for (size_t i = 0; i < nElems; ++i) {
+ if (failed(parseSectionItem<section>(ph, i)))
+ return failure();
+ }
+
+ if (!ph.end())
+ return emitError(getLocation(), "unparsed garbage at end of section ")
+ << secName;
+ return success();
+ }
+
+ /// Handles the registration of a function import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TypeIdxRecord tid) {
+ using llvm::Twine;
+ if (tid.id >= symbols.moduleFuncTypes.size())
+ return emitError(loc, "invalid type id: ")
+ << tid.id << ". Only " << symbols.moduleFuncTypes.size()
+ << " type registration.";
+ FunctionType type = symbols.moduleFuncTypes[tid.id];
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
+ importName, type);
+ symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
+ return funcOp.verify();
+ }
+
+ /// Handles the registration of a memory import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, LimitType limitType) {
+ std::string symbol = symbols.getNewMemorySymbolName();
+ auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
+ importName, limitType);
+ symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
+ return memOp.verify();
+ }
+
+ /// Handles the registration of a table import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TableType tableType) {
+ std::string symbol = symbols.getNewTableSymbolName();
+ auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
+ importName, tableType);
+ symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
+ return tableOp.verify();
+ }
+
+ /// Handles the registration of a global variable import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, GlobalTypeRecord globalType) {
+ std::string symbol = symbols.getNewGlobalSymbolName();
+ auto giOp =
+ GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
+ globalType.type, globalType.isMutable);
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+ return giOp.verify();
+ }
+
+ // Detect occurence of errors
+ LogicalResult peekDiag(Diagnostic &diag) {
+ if (diag.getSeverity() == DiagnosticSeverity::Error)
+ isValid = false;
+ return failure();
+ }
+
+public:
+ WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+ : builder{ctx}, ctx{ctx} {
+ ctx->getDiagEngine().registerHandler(
+ [this](Diagnostic &diag) { return peekDiag(diag); });
+ ctx->loadAllAvailableDialects();
+ if (sourceMgr.getNumBuffers() != 1) {
+ emitError(UnknownLoc::get(ctx), "one source file should be provided");
+ return;
+ }
+ uint32_t sourceBufId = sourceMgr.getMainFileID();
+ StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
+ srcName = StringAttr::get(
+ ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+
+ auto parser = ParserHead{source, srcName};
+ auto const wasmHeader = StringRef{"\0asm", 4};
+ FileLineColLoc magicLoc = parser.getLocation();
+ FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
+ if (failed(magic) || magic->compare(wasmHeader)) {
+ emitError(magicLoc, "source file does not contain valid Wasm header.");
+ return;
+ }
+ auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
+ FileLineColLoc versionLoc = parser.getLocation();
+ FailureOr<StringRef> version =
+ parser.consumeNBytes(expectedVersionString.size());
+ if (failed(version))
+ return;
+ if (version->compare(expectedVersionString)) {
+ emitError(versionLoc,
+ "unsupported Wasm version. only version 1 is supported");
+ return;
+ }
+ LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
+ if (failed(fillRegistry))
+ return;
+
+ mOp = ModuleOp::create(builder, getLocation());
+ builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
+ LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
+ if (failed(parsingTypes))
+ return;
+
+ LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
+ if (failed(parsingImports))
+ return;
+
+ firstInternalFuncID = symbols.funcSymbols.size();
+
+ LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
+ if (failed(parsingFunctions))
+ return;
+
+ LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
+ if (failed(parsingTables))
+ return;
+
+ LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
+ if (failed(parsingMems))
+ return;
+
+ LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>();
+ if (failed(parsingGlobals))
+ return;
+
+ LogicalResult parsingCode = parseSection<WasmSectionType::CODE>();
+ if (failed(parsingCode))
+ return;
+
+ LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
+ if (failed(parsingExports))
+ return;
+
+ // Copy over sizes of containers into statistics.
+ LDBG() << "WASM Imports:"
+ << "\n"
+ << " - Num functions: " << symbols.funcSymbols.size() << "\n"
+ << " - Num globals: " << symbols.globalSymbols.size() << "\n"
+ << " - Num memories: " << symbols.memSymbols.size() << "\n"
+ << " - Num tables: " << symbols.tableSymbols.size();
+ }
+
+ ModuleOp getModule() {
+ if (isValid)
+ return mOp;
+ if (mOp)
+ mOp.erase();
+ return ModuleOp{};
+ }
+
+private:
+ mlir::StringAttr srcName;
+ OpBuilder builder;
+ WasmModuleSymbolTables symbols;
+ MLIRContext *ctx;
+ ModuleOp mOp;
+ SectionRegistry registry;
+ size_t firstInternalFuncID{0};
+ bool isValid{true};
+};
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc importLoc = ph.getLocation();
+ auto moduleName = ph.parseName();
+ if (failed(moduleName))
+ return failure();
+
+ auto importName = ph.parseName();
+ if (failed(importName))
+ return failure();
+
+ FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
+ if (failed(import))
+ return failure();
+
+ return std::visit(
+ [this, importLoc, &moduleName, &importName](auto import) {
+ return visitImport(importLoc, *moduleName, *importName, import);
+ },
+ *import);
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc exportLoc = ph.getLocation();
+
+ auto exportName = ph.parseName();
+ if (failed(exportName))
+ return failure();
+
+ FailureOr<std::byte> opcode = ph.consumeByte();
+ if (failed(opcode))
+ return failure();
+
+ FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
+ if (failed(idx))
+ return failure();
+
+ using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
+ SmallVector<GlobalSymbolRefContainer>,
+ SmallVector<FunctionSymbolRefContainer>>;
+
+ SymbolRefDesc currentSymbolList;
+ std::string symbolType = "";
+ switch (*opcode) {
+ case WasmBinaryEncoding::Export::function:
+ symbolType = "function";
+ currentSymbolList = symbols.funcSymbols;
+ break;
+ case WasmBinaryEncoding::Export::table:
+ symbolType = "table";
+ currentSymbolList = symbols.tableSymbols;
+ break;
+ case WasmBinaryEncoding::Export::memory:
+ symbolType = "memory";
+ currentSymbolList = symbols.memSymbols;
+ break;
+ case WasmBinaryEncoding::Export::global:
+ symbolType = "global";
+ currentSymbolList = symbols.globalSymbols;
+ break;
+ default:
+ return emitError(exportLoc, "invalid value for export type: ")
+ << std::to_integer<unsigned>(*opcode);
+ }
+
+ auto currentSymbol = std::visit(
+ [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
+ if (*idx > list.size()) {
+ emitError(
+ exportLoc,
+ llvm::formatv(
+ "trying to export {0} {1} which is undefined in this scope",
+ symbolType, *idx));
+ return failure();
+ }
+ return list[*idx].symbol;
+ },
+ currentSymbolList);
+
+ if (failed(currentSymbol))
+ return failure();
+
+ Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
+ SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+ StringAttr symName = SymbolTable::getSymbolName(op);
+ return SymbolTable{mOp}.rename(symName, *exportName);
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<TableType> tableType = ph.parseTableType(ctx);
+ if (failed(tableType))
+ return failure();
+ LDBG() << " Parsed table description: " << *tableType;
+ StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
+ auto tableOp =
+ TableOp::create(builder, opLocation, symbol.strref(), *tableType);
+ symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLoc = ph.getLocation();
+ auto typeIdxParsed = ph.parseLiteral<uint32_t>();
+ if (failed(typeIdxParsed))
+ return failure();
+ uint32_t typeIdx = *typeIdxParsed;
+ if (typeIdx >= symbols.moduleFuncTypes.size())
+ return emitError(getLocation(), "invalid type index: ") << typeIdx;
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp =
+ FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
+ Block *block = funcOp.addEntryBlock();
+ OpBuilder::InsertionGuard guard{builder};
+ builder.setInsertionPointToEnd(block);
+ ReturnOp::create(builder, opLoc);
+ symbols.funcSymbols.push_back(
+ {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
+ symbols.moduleFuncTypes[typeIdx]});
+ return funcOp.verify();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
+ size_t) {
+ FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
+ if (failed(funcType))
+ return failure();
+ LDBG() << "Parsed function type " << *funcType;
+ symbols.moduleFuncTypes.push_back(*funcType);
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<LimitType> memory = ph.parseLimit(ctx);
+ if (failed(memory))
+ return failure();
+
+ LDBG() << " Registering memory " << *memory;
+ std::string symbol = symbols.getNewMemorySymbolName();
+ auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
+ symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc globalLocation = ph.getLocation();
+ auto globalTypeParsed = ph.parseGlobalType(ctx);
+ if (failed(globalTypeParsed))
+ return failure();
+
+ GlobalTypeRecord globalType = *globalTypeParsed;
+ auto symbol = builder.getStringAttr(symbols.getNewGlobalSymbolName());
+ auto globalOp = builder.create<wasmssa::GlobalOp>(
+ globalLocation, symbol, globalType.type, globalType.isMutable);
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(globalOp)}, globalOp.getType()});
+ OpBuilder::InsertionGuard guard{builder};
+ Block *block = builder.createBlock(&globalOp.getInitializer());
+ builder.setInsertionPointToStart(block);
+ parsed_inst_t expr = ph.parseExpression(builder, symbols);
+ if (failed(expr))
+ return failure();
+ if (block->empty())
+ return emitError(globalLocation, "global with empty initializer");
+ if (expr->size() != 1 && (*expr)[0].getType() != globalType.type)
+ return emitError(
+ globalLocation,
+ "initializer result type does not match global declaration type");
+ builder.create<ReturnOp>(globalLocation, *expr);
+ return success();
+}
+
+template <>
+LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>(
+ ParserHead &ph, size_t innerFunctionId) {
+ unsigned long funcId = innerFunctionId + firstInternalFuncID;
+ FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId];
+ auto funcOp =
+ dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(mOp, symRef.symbol));
+ assert(funcOp);
+ if (failed(ph.parseCodeFor(funcOp, symbols)))
+ return failure();
+ return success();
+}
+} // namespace
+
+namespace mlir::wasm {
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+ MLIRContext *context) {
+ WasmBinaryParser wBN{source, context};
+ ModuleOp mOp = wBN.getModule();
+ if (mOp)
+ return {mOp};
+
+ return {nullptr};
+}
+} // namespace mlir::wasm
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
new file mode 100644
index 0000000..03b9784
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -0,0 +1,28 @@
+//===- TranslateRegistration.cpp - Register translation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+using namespace mlir;
+
+namespace mlir {
+void registerFromWasmTranslation() {
+ TranslateToMLIRRegistration registration{
+ "import-wasm", "Translate WASM to MLIR",
+ [](llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) -> OwningOpRef<Operation *> {
+ return wasm::importWebAssemblyToModule(sourceMgr, context);
+ },
+ [](DialectRegistry &registry) {
+ registry.insert<wasmssa::WasmSSADialect>();
+ }};
+}
+} // namespace mlir
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 51e702a..c883baa 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -147,8 +147,9 @@ private:
std::string docStr;
{
llvm::raw_string_ostream docOS(docStr);
+ std::string tmpDocStr = doc.str();
raw_indented_ostream(docOS).printReindented(
- StringRef(docStr).rtrim(" \t"));
+ StringRef(tmpDocStr).rtrim(" \t"));
}
return docStr;
}
diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
index 9950050..6945c09 100644
--- a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
+++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
@@ -21,6 +21,7 @@
#include "llvm/LineEditor/LineEditor.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/Process.h"
#include "llvm/Support/SourceMgr.h"
//===----------------------------------------------------------------------===//
@@ -43,7 +44,7 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory));
static llvm::cl::opt<std::string> inputFilename(
- llvm::cl::Positional, llvm::cl::desc("<input file>"),
+ llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::cat(mlirQueryCategory));
static llvm::cl::opt<bool> noImplicitModule{
@@ -68,6 +69,14 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
return mlir::success();
}
+ // When reading from stdin and the input is a tty, it is often a user mistake
+ // and the process "appears to be stuck". Print a message to let the user
+ // know!
+ if (inputFilename == "-" &&
+ llvm::sys::Process::FileDescriptorIsDisplayed(fileno(stdin)))
+ llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to "
+ "interrupt)\n";
+
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
index e89d392..34459b8 100644
--- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
+++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
@@ -26,9 +26,9 @@
using namespace mlir;
// Parse and verify the input MLIR file. Returns null on error.
-OwningOpRef<Operation *> loadModule(MLIRContext &context,
- StringRef inputFilename,
- bool insertImplictModule) {
+static OwningOpRef<Operation *> loadModule(MLIRContext &context,
+ StringRef inputFilename,
+ bool insertImplictModule) {
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
@@ -65,6 +65,11 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
"Disable implicit addition of a top-level module op during parsing"),
llvm::cl::init(false)};
+ static llvm::cl::opt<bool> allowUnregisteredDialects(
+ "allow-unregistered-dialect",
+ llvm::cl::desc("Allow operation with no registered dialects"),
+ llvm::cl::init(false));
+
llvm::cl::HideUnrelatedOptions(mlirReduceCategory);
llvm::InitLLVM y(argc, argv);
@@ -79,6 +84,8 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
llvm::cl::PrintHelpMessage();
return success();
}
+ if (allowUnregisteredDialects)
+ context.allowUnregisteredDialects();
std::string errorMessage;
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 09e5a02..8eaac30 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -177,11 +177,10 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
Operation *toOp) {
assert(fromOp->getBlock() == toOp->getBlock());
- assert(
- isa<MemoryEffectOpInterface>(fromOp) &&
- cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
- isa<MemoryEffectOpInterface>(toOp) &&
- cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+ assert(hasEffect<MemoryEffects::Read>(fromOp) &&
+ "expected read effect on fromOp");
+ assert(hasEffect<MemoryEffects::Read>(toOp) &&
+ "expected read effect on toOp");
Operation *nextOp = fromOp->getNextNode();
auto result =
memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
@@ -245,11 +244,10 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
// Some simple use case of operation with memory side-effect are dealt with
// here. Operations with no side-effect are done after.
if (!isMemoryEffectFree(op)) {
- auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
// TODO: Only basic use case for operations with MemoryEffects::Read can be
// eleminated now. More work needs to be done for more complicated patterns
// and other side-effects.
- if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+ if (!hasSingleEffect<MemoryEffects::Read>(op))
return failure();
// Look for an existing definition for the operation.
diff --git a/mlir/lib/Transforms/InlinerPass.cpp b/mlir/lib/Transforms/InlinerPass.cpp
index 703e517..77a9e6c 100644
--- a/mlir/lib/Transforms/InlinerPass.cpp
+++ b/mlir/lib/Transforms/InlinerPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Inliner.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_INLINER
@@ -120,8 +121,8 @@ static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
return true;
unsigned ratio = countOps(calleeRegion) * 100 / callerOps;
- LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
- << inliningThreshold << "%): " << ratio << "%\n");
+ LDBG() << "Callee / caller operation ratio (max: " << inliningThreshold
+ << "%): " << ratio << "%";
return ratio <= inliningThreshold;
}
@@ -138,7 +139,7 @@ void InlinerPass::runOnOperation() {
}
// By default, assume that any inlining is profitable.
- auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
+ auto profitabilityCb = [this](const Inliner::ResolvedCall &call) {
return isProfitableToInline(call, inliningThreshold);
};
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index cf039c3..d36a3c1 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
namespace mlir {
@@ -632,8 +633,7 @@ MemorySlotPromoter::promoteSlot() {
}
}
- LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
- << "\n");
+ LDBG() << "Promoted memory slot: " << slot.ptr;
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 4ccb83f..0e84b6d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG()
- << "Simple op is not memory effect free or has live results, skipping: "
- << *op;
+ LDBG() << "Simple op is not memory effect free or has live results, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return;
}
LDBG()
<< "Simple op has all dead results and is memory effect free, scheduling "
"for removal: "
- << *op;
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -345,8 +344,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// being returned, in order to optimize our IR. So, this demonstrates how we
// can make our optimization strong by even removing a live return value (%0),
// since it forwards only to non-live value(s) (%1#1).
- Operation *lastReturnOp = funcOp.back().getTerminator();
- size_t numReturns = lastReturnOp->getNumOperands();
+ size_t numReturns = funcOp.getNumResults();
BitVector nonLiveRets(numReturns, true);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -728,19 +726,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+ LDBG() << "Starting cleanup of dead values...";
+
// 1. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (auto &op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
op->dropAllUses();
op->erase();
}
// 2. Values
+ LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
+ LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}
// 3. Functions
+ LDBG() << "Cleaning up " << list.functions.size() << " functions";
for (auto &f : list.functions) {
+ LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
+ LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
+ LDBG() << " Erasing " << f.nonLiveRets.count()
+ << " non-live return values";
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
@@ -749,44 +759,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
// 4. Operands
+ LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0)
+ if (o.op->getNumOperands() > 0) {
+ LDBG() << "Erasing " << o.nonLive.count()
+ << " non-live operands from operation: "
+ << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
o.op->eraseOperands(o.nonLive);
+ }
}
// 5. Results
+ LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
+ LDBG() << "Erasing " << r.nonLive.count()
+ << " non-live results from operation: "
+ << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}
// 6. Blocks
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
+ LDBG() << "Erasing " << b.nonLiveArgs.count()
+ << " non-live arguments from block: " << b.b;
// it iterates backwards because erase invalidates all successor indexes
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
+ LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}
// 7. Successor Operands
+ LDBG() << "Cleaning up " << list.successorOperands.size()
+ << " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
+ LDBG() << "Erasing " << op.nonLiveOperands.count()
+ << " non-live successor operands from successor "
+ << op.successorIndex << " of branch: "
+ << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
// it iterates backwards because erase invalidates all successor indexes
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
+ LDBG() << " Erasing successor operand " << i << ": "
+ << successorOperands[i];
successorOperands.erase(i);
}
}
+
+ LDBG() << "Finished cleanup of dead values";
}
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67f536a..859c030 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -12,6 +12,7 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_SROA
@@ -180,8 +181,7 @@ static void destructureSlot(
assert(slot.ptr.use_empty() && "after destructuring, the original slot "
"pointer should no longer be used");
- LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
- << "\n");
+ LDBG() << "Destructured memory slot: " << slot.ptr;
if (statistics.destructuredAmount)
(*statistics.destructuredAmount)++;
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 0a925c4..87885be 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -13,8 +13,11 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
namespace mlir {
#define GEN_PASS_DEF_SYMBOLDCE
@@ -87,8 +90,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
SymbolTableCollection &symbolTable,
bool symbolTableIsHidden,
DenseSet<Operation *> &liveSymbols) {
- LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
- << "\n");
+ LDBG() << "computeLiveness: "
+ << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
// A worklist of live operations to propagate uses from.
SmallVector<Operation *, 16> worklist;
@@ -116,7 +119,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// consideration.
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
- LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
+ LDBG() << "processing: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If this is a symbol table, recursively compute its liveness.
if (op->hasTrait<OpTrait::SymbolTable>()) {
@@ -124,13 +128,14 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// symbol, or if it is a private symbol.
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
- LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
- << " is hidden: " << symIsHidden << "\n");
+ LDBG() << "\tsymbol table: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " is hidden: " << symIsHidden;
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
return failure();
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "\tnon-symbol table: " << op->getName() << "\n");
+ LDBG() << "\tnon-symbol table: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the op is not a symbol table, then, unless op itself is dead which
// would be handled by DCE, we need to check all the regions and blocks
// within the op to find the uses (e.g., consider visibility within op as
@@ -160,20 +165,17 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
}
SmallVector<Operation *, 4> resolvedSymbols;
- LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
+ LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (const SymbolTable::SymbolUse &use : *uses) {
- LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
+ LDBG() << "\tuse: " << use.getUser();
// Lookup the symbols referenced by this use.
resolvedSymbols.clear();
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
resolvedSymbols)))
// Ignore references to unknown symbols.
continue;
- LLVM_DEBUG({
- llvm::dbgs() << "\t\tresolved symbols: ";
- llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "\t\tresolved symbols: "
+ << llvm::interleaved(resolvedSymbols, ", ");
// Mark each of the resolved symbols as live.
for (Operation *resolvedSymbol : resolvedSymbols)
diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
index cfd568f..19cf464 100644
--- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -21,7 +21,10 @@
#include "mlir/Transforms/ControlFlowSinkUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/Support/DebugLog.h"
#include <vector>
#define DEBUG_TYPE "cf-sink"
@@ -84,13 +87,15 @@ bool Sinker::allUsersDominatedBy(Operation *op, Region *region) {
void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
std::vector<Operation *> &stack) {
- LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n"));
+ LDBG() << "Contained op: "
+ << OpWithFlags(user, OpPrintingFlags().skipRegions());
for (Value value : user->getOperands()) {
Operation *op = value.getDefiningOp();
// Ignore block arguments and ops that are already inside the region.
if (!op || op->getParentRegion() == region)
continue;
- LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n"));
+ LDBG() << "Try to sink:\n"
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the op's users are all in the region and it can be moved, then do so.
if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c26b4e..5ba109d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -182,15 +182,24 @@ private:
/// conversions.)
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values)) {
+ if (v.getDefiningOp() != op)
+ return nullptr;
+ }
+ return op;
+}
+
/// A vector of values is a pure type conversion if all values are defined by
/// the same operation and the operation has the `kPureTypeConversionMarker`
/// attribute.
static bool isPureTypeConversion(const ValueVector &values) {
assert(!values.empty() && "expected non-empty value vector");
- Operation *op = values.front().getDefiningOp();
- for (Value v : llvm::drop_begin(values))
- if (v.getDefiningOp() != op)
- return false;
+ Operation *op = getCommonDefiningOp(values);
return op && op->hasAttr(kPureTypeConversionMarker);
}
@@ -839,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -863,6 +873,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// failure.
template <typename RewriteTy, typename... Args>
void appendRewrite(Args &&...args) {
+ assert(config.allowPatternRollback && "appending rewrites is not allowed");
rewrites.push_back(
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
}
@@ -877,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasOpReplaced(Operation *op) const;
/// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
+ /// mapping, taking into account only replacements. Perform a best-effort
+ /// search for existing materializations with the desired types.
///
/// If `skipPureTypeConversions` is "true", materializations that are pure
/// type conversions are not considered.
@@ -915,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -926,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -976,7 +977,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr,
bool isPureTypeConversion = true);
/// Find a replacement value for the given SSA value in the conversion value
@@ -1058,14 +1058,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
+ /// This vector is maintained only if `allowPatternRollback` is set to
+ /// "true". Otherwise, all IR rewrites are materialized immediately and no
+ /// bookkeeping is needed.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization.
@@ -1089,6 +1092,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
+ /// A list of unresolved materializations that were created by the current
+ /// pattern.
+ DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
@@ -1104,15 +1111,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// A set of erased operations. This set is utilized only if
+ /// `allowPatternRollback` is set to "false". Conceptually, this set is
+ /// similar to `replacedOps` (which is maintained when the flag is set to
+ /// "true"). However, erasing from a DenseSet is more efficient than erasing
+ /// from a SetVector.
+ DenseSet<Operation *> erasedOps;
+
+ /// A set of erased blocks. This set is utilized only if
+ /// `allowPatternRollback` is set to "false".
+ DenseSet<Block *> erasedBlocks;
+
+ /// A rewriter that notifies the listener (if any) about all IR
+ /// modifications. This rewriter is utilized only if `allowPatternRollback`
+ /// is set to "false". If the flag is set to "true", the listener is notified
+ /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+ IRRewriter notifyingRewriter;
+
#ifndef NDEBUG
+ /// A set of replaced block arguments. This set is for debugging purposes
+ /// only and it is maintained only if `allowPatternRollback` is set to
+ /// "true".
+ DenseSet<BlockArgument> replacedArgs;
+
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
/// verification.
SmallPtrSet<Operation *, 1> pendingRootUpdates;
/// A raw output stream used to prefix the debug log.
- llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(),
- llvm::dbgs(), /*HasPendingNewline=*/false};
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit diagnostics during the conversion process.
llvm::ScopedPrinter logger{os};
@@ -1140,11 +1169,8 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
- if (!repl)
- return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
rewriter.replaceAllUsesWith(arg, repl);
return;
@@ -1161,6 +1187,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ if (!repl)
+ return;
+ performReplaceBlockArg(rewriter, arg, repl);
+}
+
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1223,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1246,6 +1280,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up a single value.
+ auto lookup = [&](const ValueVector &values) -> ValueVector {
+ assert(!values.empty() && "expected non-empty value vector");
+
+ // If the pattern rollback is enabled, use the mapping to look up the
+ // values.
+ if (config.allowPatternRollback)
+ return mapping.lookup(values);
+
+ // Otherwise, look up values by examining the IR. All replacements have
+ // already been materialized in IR.
+ Operation *op = getCommonDefiningOp(values);
+ if (!op)
+ return {};
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+ if (!castOp)
+ return {};
+ if (!this->unresolvedMaterializations.contains(castOp))
+ return {};
+ if (castOp.getOutputs() != values)
+ return {};
+ return castOp.getInputs();
+ };
+
// Helper function that looks up each value in `values` individually and then
// composes the results. If that fails, it tries to look up the entire vector
// at once.
@@ -1253,7 +1311,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// If possible, replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : values) {
- ValueVector r = mapping.lookup({v});
+ ValueVector r = lookup({v});
if (!r.empty()) {
llvm::append_range(next, r);
} else {
@@ -1273,7 +1331,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
- ValueVector r = mapping.lookup(values);
+ ValueVector r = lookup(values);
if (r.empty()) {
// No mapping found: The lookup stops here.
return {};
@@ -1347,21 +1405,13 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
- if (!config.allowPatternRollback &&
- !isa<UnresolvedMaterializationRewrite>(rewrite)) {
- // Unresolved materializations can always be rolled back (erased).
- llvm::report_fatal_error("pattern '" + patternName +
- "' rollback of IR modifications requested");
- }
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
rewrite->rollback();
- }
rewrites.resize(numRewritesToKeep);
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1383,7 +1433,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -1419,12 +1469,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ return wasOpReplaced(op) || ignoredOps.count(op);
}
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
- return replacedOps.count(op);
+ return replacedOps.count(op) || erasedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1432,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1448,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, &region->front(), &converter,
+ return applySignatureConversion(&region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, &region->front(), &converter,
- *conversion);
+ return applySignatureConversion(&region->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -1508,7 +1555,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// a bit more efficient, so we try to do that when possible.
bool fastPath = !config.listener;
if (fastPath) {
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ if (config.allowPatternRollback)
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
} else {
while (!block->empty())
@@ -1534,7 +1582,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
- /*castOp=*/nullptr, /*isPureTypeConversion=*/false)
+ /*isPureTypeConversion=*/false)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
@@ -1556,7 +1604,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+ if (config.allowPatternRollback)
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1575,7 +1624,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
+ bool isPureTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
@@ -1585,23 +1634,35 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// tracking the materialization like we do for other operations.
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- auto convertOp =
+ UnrealizedConversionCastOp convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+ if (config.attachDebugMaterializationKind) {
+ StringRef kindStr =
+ kind == MaterializationKind::Source ? "source" : "target";
+ convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+ }
if (isPureTypeConversion)
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
- if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
- if (castOp)
- *castOp = convertOp;
+
+ // Register the materialization.
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
- std::move(valuesToMap));
+ if (config.allowPatternRollback) {
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
+ } else {
+ patternMaterializations.insert(convertOp);
+ }
return convertOp.getResults();
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ assert(config.allowPatternRollback &&
+ "this code path is valid only in rollback mode");
+
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
@@ -1663,26 +1724,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
"attempting to insert into a block within a replaced/erased op");
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+
if (wasDetached) {
- // If the op was detached, it is most likely a newly created op.
- // TODO: If the same op is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same op multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateOperationRewrite>(op);
+ // If the op was detached, it is most likely a newly created op. Add it the
+ // set of newly created ops, so that it will be legalized. If this op is
+ // not a newly created op, it will be legalized a second time, which is
+ // inefficient but harmless.
patternNewOps.insert(op);
+
+ if (config.allowPatternRollback) {
+ // TODO: If the same op is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same op multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateOperationRewrite>(op);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased operations that must be kept up to date.
+ erasedOps.erase(op);
+ }
return;
}
// The op was moved from one place to another.
- appendRewrite<MoveOperationRewrite>(op, previous);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveOperationRewrite>(op, previous);
+}
+
+/// Given that `fromRange` is about to be replaced with `toRange`, compute
+/// replacement values with the types of `fromRange`.
+static SmallVector<Value>
+getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
+ const SmallVector<SmallVector<Value>> &toRange,
+ const TypeConverter *converter) {
+ assert(!impl.config.allowPatternRollback &&
+ "this code path is valid only in 'no rollback' mode");
+ SmallVector<Value> repls;
+ for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+ if (from.use_empty()) {
+ // The replaced value is dead. No replacement value is needed.
+ repls.push_back(Value());
+ continue;
+ }
+
+ if (to.empty()) {
+ // The replaced value is dropped. Materialize a replacement value "out of
+ // thin air".
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+ /*outputTypes=*/from.getType(), /*originalType=*/Type(),
+ converter)[0];
+ repls.push_back(srcMat);
+ continue;
+ }
+
+ if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+ // The replacement value already has the correct type. Use it directly.
+ repls.push_back(to[0]);
+ continue;
+ }
+
+ // The replacement value has the wrong type. Build a source materialization
+ // to the original type.
+ // TODO: This is a bit inefficient. We should try to reuse existing
+ // materializations if possible. This would require an extension of the
+ // `lookupOrDefault` API.
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+ /*originalType=*/Type(), converter)[0];
+ repls.push_back(srcMat);
+ }
+
+ return repls;
}
void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
- assert(newValues.size() == op->getNumResults());
+ assert(newValues.size() == op->getNumResults() &&
+ "incorrect number of replacement values");
+
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ SmallVector<Value> repls = getReplacementValues(
+ *this, op->getResults(), newValues, currentTypeConverter);
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ op->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ op->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Replace the op with the replacement values and notify the listener.
+ notifyingRewriter.replaceOp(op, repls);
+ return;
+ }
+
assert(!ignoredOps.contains(op) && "operation was already replaced");
// Check if replaced op is an unresolved materialization, i.e., an
@@ -1704,8 +1858,7 @@ void ConversionPatternRewriterImpl::replaceOp(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
- currentTypeConverter, /*castOp=*/nullptr,
- /*isPureTypeConversion=*/false);
+ currentTypeConverter, /*isPureTypeConversion=*/false);
continue;
}
@@ -1722,11 +1875,59 @@ void ConversionPatternRewriterImpl::replaceOp(
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
+ if (!config.allowPatternRollback) {
+ SmallVector<Value> toConv = llvm::to_vector(to);
+ SmallVector<Value> repls =
+ getReplacementValues(*this, from, {toConv}, converter);
+ IRRewriter r(from.getContext());
+ Value repl = repls.front();
+ if (!repl)
+ return;
+
+ performReplaceBlockArg(r, from, repl);
+ return;
+ }
+
+#ifndef NDEBUG
+ // Make sure that a block argument is not replaced multiple times. In
+ // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
+ // uses of the given block argument, but also all future uses that may be
+ // introduced by future pattern applications. Therefore, it does not make
+ // sense to call `replaceUsesOfBlockArgument` multiple times with the same
+ // block argument. Doing so would overwrite the mapping and mess with the
+ // internal state of the dialect conversion driver.
+ assert(!replacedArgs.contains(from) &&
+ "attempting to replace a block argument that was already replaced");
+ replacedArgs.insert(from);
+#endif // NDEBUG
+
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ block->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ block->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Erase the block and notify the listener.
+ notifyingRewriter.eraseBlock(block);
+ return;
+ }
+
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
@@ -1760,23 +1961,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(newParentOp) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
"attempting to insert into a region within a replaced/erased op");
(void)newParentOp;
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyBlockInserted(block, previous, previousIt);
+
patternInsertedBlocks.insert(block);
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
- // TODO: If the same block is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same block multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateBlockRewrite>(block);
+ if (config.allowPatternRollback) {
+ // TODO: If the same block is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same block multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateBlockRewrite>(block);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased blocks that must be kept up to date.
+ erasedBlocks.erase(block);
+ }
return;
}
// The block was moved from one place to another.
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1803,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -1880,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1889,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -1908,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -1921,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -1956,7 +2171,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// a bit more efficient, so we try to do that when possible.
bool fastPath = !getConfig().listener;
- if (fastPath)
+ if (fastPath && impl->config.allowPatternRollback)
impl->inlineBlockBefore(source, dest, before);
// Replace all uses of block arguments.
@@ -1982,6 +2197,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ // Pattern rollback is not allowed: no extra bookkeeping is needed.
+ PatternRewriter::startOpModification(op);
+ return;
+ }
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
#ifndef NDEBUG
@@ -1991,20 +2211,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
- assert(!impl->wasOpReplaced(op) &&
- "attempting to modify a replaced/erased op");
- PatternRewriter::finalizeOpModification(op);
impl->patternModifiedOps.insert(op);
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::finalizeOpModification(op);
+ if (getConfig().listener)
+ getConfig().listener->notifyOperationModified(op);
+ return;
+ }
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
}
void ConversionPatternRewriter::cancelOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::cancelOpModification(op);
+ return;
+ }
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -2029,17 +2258,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
ArrayRef<ValueRange> operands) const {
SmallVector<Value> oneToOneOperands;
oneToOneOperands.reserve(operands.size());
for (ValueRange operand : operands) {
if (operand.size() != 1)
- llvm::report_fatal_error("pattern '" + getDebugName() +
- "' does not support 1:N conversion");
+ return failure();
+
oneToOneOperands.push_back(operand.front());
}
- return oneToOneOperands;
+ return std::move(oneToOneOperands);
}
LogicalResult
@@ -2054,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2076,7 +2305,8 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
@@ -2084,29 +2314,25 @@ public:
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2115,18 +2341,12 @@ private:
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2169,6 +2389,9 @@ private:
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget &target;
@@ -2177,9 +2400,10 @@ private:
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2193,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2257,19 +2479,21 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
- // If the operation isn't legal, try to fold it in-place.
- // TODO: Should we always try to do this, even if the op is
- // already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
- LLVM_DEBUG({
- logSuccess(logger, "operation was folded");
- logger.startLine() << logLineComment;
- });
- return success();
+ // If the operation is not legal, try to fold it in-place if the folding mode
+ // is 'BeforePatterns'. 'Never' will skip this.
+ const ConversionConfig &config = rewriter.getConfig();
+ if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+ if (succeeded(legalizeWithFold(op))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2277,6 +2501,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation can't be legalized via patterns, try to fold it in-place
+ // if the folding mode is 'AfterPatterns'.
+ if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+ if (succeeded(legalizeWithFold(op))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
+ }
+
LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
@@ -2293,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2329,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
@@ -2381,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
llvm::join(insertedBlockNames, ", ") + "}");
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();
@@ -2415,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2425,17 +2657,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (!rewriterImpl.config.allowPatternRollback) {
- // Returning "failure" after modifying IR is not allowed.
+ // Erase all unresolved materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ rewriterImpl.patternMaterializations.clear();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Expensive pattern check that can detect API violations.
if (checkOp) {
OperationFingerPrint fingerPrintAfterPattern(checkOp);
if (fingerPrintAfterPattern != *topLevelFingerPrint)
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
"' returned failure but IR did change");
}
- }
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ }
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2459,12 +2697,22 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Eagerly erase unused materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ if (op->use_empty()) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ }
+ rewriterImpl.patternMaterializations.clear();
+ }
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2483,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2506,11 +2754,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
- auto &impl = rewriter.getImpl();
+ [[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -2528,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2540,15 +2787,17 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (Block *block : insertedBlocks) {
+ if (impl.erasedBlocks.contains(block))
+ continue;
+
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2564,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure();
}
- impl.applySignatureConversion(rewriter, block, converter, *conversion);
+ impl.applySignatureConversion(block, converter, *conversion);
continue;
}
@@ -2573,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// operation, and blocks in regions created by this pattern will already be
// legalized later on.
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp, rewriter))) {
+ if (failed(legalize(parentOp))) {
LLVM_DEBUG(logFailure(
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
@@ -2585,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
}
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(impl.logger,
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
"failed to legalize generated operation '{0}'({1})",
op->getName(), op));
return failure();
@@ -2599,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
}
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &modifiedOps) {
for (Operation *op : modifiedOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "failed to legalize operation updated in-place '{0}'",
- op->getName()));
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(
+ logFailure(rewriter.getImpl().logger,
+ "failed to legalize operation updated in-place '{0}'",
+ op->getName()));
return failure();
}
}
@@ -2825,21 +3073,22 @@ namespace mlir {
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
- explicit OperationConverter(const ConversionTarget &target,
+ explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : config(config), opLegalizer(target, patterns), mode(mode) {}
+ : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+ mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
- LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+ LogicalResult convert(Operation *op);
- /// Dialect conversion configuration.
- ConversionConfig config;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter rewriter;
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
@@ -2849,10 +3098,11 @@ private:
};
} // namespace mlir
-LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
- Operation *op) {
+LogicalResult OperationConverter::convert(Operation *op) {
+ const ConversionConfig &config = rewriter.getConfig();
+
// Legalize the given operation.
- if (failed(opLegalizer.legalize(op, rewriter))) {
+ if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
@@ -2928,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
}
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
- assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
@@ -2947,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
for (auto *op : toConvert) {
- if (failed(convert(rewriter, op))) {
+ if (failed(convert(op))) {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
@@ -2986,13 +3234,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
castOp->removeAttr(kPureTypeConversionMarker);
// Try to legalize all unresolved materializations.
- if (config.buildMaterializations) {
- IRRewriter rewriter(rewriterImpl.context, config.listener);
+ if (rewriter.getConfig().buildMaterializations) {
+ // Use a new rewriter, so the modifications are not tracked for rollback
+ // purposes etc.
+ IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
+ rewriter.getConfig().listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(
- legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
+ if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
+ it->second)))
return failure();
}
}
@@ -3159,6 +3410,27 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions)
+ return convertType(v.getType(), results);
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -3169,6 +3441,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -3178,21 +3460,38 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
+
+bool TypeConverter::isLegal(Value value) const {
+ return convertType(value) == value.getType();
+}
+
bool TypeConverter::isLegal(Operation *op) const {
- return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
+ return isLegal(op->getOperands()) && isLegal(op->getResults());
}
bool TypeConverter::isLegal(Region *region) const {
- return llvm::all_of(*region, [this](Block &block) {
- return isLegal(block.getArgumentTypes());
- });
+ return llvm::all_of(
+ *region, [this](Block &block) { return isLegal(block.getArguments()); });
}
bool TypeConverter::isSignatureLegal(FunctionType ty) const {
- return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
+ if (!isLegal(ty.getInputs()))
+ return false;
+ if (!isLegal(ty.getResults()))
+ return false;
+ return true;
}
LogicalResult
@@ -3220,6 +3519,31 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return failure();
return success();
}
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
+ SignatureConversion &result) const {
+ // Try to convert the given input type.
+ SmallVector<Type, 1> convertedTypes;
+ if (failed(convertType(value, convertedTypes)))
+ return failure();
+
+ // If this argument is being dropped, there is nothing left to do.
+ if (convertedTypes.empty())
+ return success();
+
+ // Otherwise, add the new inputs.
+ result.addInputs(inputNo, convertedTypes);
+ return success();
+}
+LogicalResult
+TypeConverter::convertSignatureArgs(ValueRange values,
+ SignatureConversion &result,
+ unsigned origInputOffset) const {
+ for (unsigned i = 0, e = values.size(); i != e; ++i)
+ if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
+ return failure();
+ return success();
+}
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType,
@@ -3263,7 +3587,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion(
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
- if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
+ if (failed(convertSignatureArgs(block->getArguments(), conversion)))
return std::nullopt;
return conversion;
}
@@ -3388,7 +3712,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
@@ -3661,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
ctx->executeAction<ApplyConversionAction>(
[&] {
- OperationConverter opConverter(target, patterns, config, mode);
+ OperationConverter opConverter(ops.front()->getContext(), target,
+ patterns, config, mode);
status = opConverter.convertOperations(ops);
},
irUnits);
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 607b86c..0324588 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -15,6 +15,8 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
@@ -23,7 +25,7 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
@@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) {
return op;
}
static void logSuccessfulFolding(Operation *op) {
- llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
- op->dump();
- llvm::dbgs() << "\n\n";
+ LDBG() << "// *** IR Dump After Successful Folding ***\n"
+ << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
}
#endif // NDEBUG
@@ -394,8 +395,12 @@ private:
function_ref<void(Diagnostic &)> reasonCallback) override;
#ifndef NDEBUG
+ /// A raw output stream used to prefix the debug log.
+
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
+ llvm::ScopedPrinter logger{os};
#endif
/// The low-level pattern applicator.
@@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
- continueRewrites = processWorklist();
+ continueRewrites = false;
+
+ // Erase unreachable blocks
+ // Operations like:
+ // %add = arith.addi %add, %add : i64
+ // are legal in unreachable code. Unfortunately many patterns would be
+ // unsafe to apply on such IR and can lead to crashes or infinite
+ // loops.
+ continueRewrites |=
+ succeeded(eraseUnreachableBlocks(rewriter, region));
+
+ continueRewrites |= processWorklist();
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
@@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region &region,
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
region);
LogicalResult converged = std::move(driver).simplify(changed);
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after scanning "
- << config.getMaxIterations() << " times\n";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after scanning "
+ << config.getMaxIterations() << " times";
return converged;
}
@@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily(
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
*allErased = surviving.empty();
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.getMaxNumRewrites() << " rewrites";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after "
+ << config.getMaxNumRewrites() << " rewrites";
return converged;
}
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index eeb4052..73107cf 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -13,10 +13,12 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -182,13 +184,16 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
IRMapping &valueMapping) {
for (auto &block : *src) {
for (auto &op : block) {
+ // UnrealizedConversionCastOp is inlineable but cannot implement the
+ // inliner interface due to layering constraints.
+ if (isa<UnrealizedConversionCastOp>(op))
+ continue;
+
// Check this operation.
if (!interface.isLegalToInline(&op, insertRegion,
shouldCloneInlinedRegion, valueMapping)) {
- LLVM_DEBUG({
- llvm::dbgs() << "* Illegal to inline because of op: ";
- op.dump();
- });
+ LDBG() << "* Illegal to inline because of op: "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
return false;
}
// Check any nested regions.
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index cb3f2c5..111f58e 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -13,11 +13,13 @@
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <queue>
#define DEBUG_TYPE "licm"
@@ -64,8 +66,7 @@ size_t mlir::moveLoopInvariantCode(
size_t numMoved = 0;
for (Region *region : regions) {
- LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
- << *region->getParentOp() << "\n");
+ LDBG() << "Original loop:\n" << *region->getParentOp();
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
@@ -83,12 +84,13 @@ size_t mlir::moveLoopInvariantCode(
if (op->getParentRegion() != region)
continue;
- LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
+ LDBG() << "Checking op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
if (!shouldMoveOutOfRegion(op, region) ||
!canBeHoisted(op, definedOutside))
continue;
- LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
+ LDBG() << "Moving loop-invariant op: " << *op;
moveOutOfRegion(op, region);
++numMoved;
@@ -322,7 +324,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
- auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index a1d975d..31ae1d1 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -23,12 +23,15 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <deque>
#include <iterator>
using namespace mlir;
+#define DEBUG_TYPE "region-utils"
+
void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
@@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
+ LDBG() << "Starting eraseUnreachableBlocks with " << regions.size()
+ << " regions";
+
// Set of blocks found to be reachable within a given region.
llvm::df_iterator_default_set<Block *, 16> reachable;
// If any blocks were found to be dead.
- bool erasedDeadBlocks = false;
+ int erasedDeadBlocks = 0;
SmallVector<Region *, 1> worklist;
worklist.reserve(regions.size());
for (Region &region : regions)
worklist.push_back(&region);
+
+ LDBG(2) << "Initial worklist size: " << worklist.size();
+
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();
- if (region->empty())
+ if (region->empty()) {
+ LDBG(2) << "Skipping empty region";
continue;
+ }
+
+ LDBG(2) << "Processing region with " << region->getBlocks().size()
+ << " blocks";
+ if (region->getParentOp())
+ LDBG(2) << " -> for operation: "
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
// If this is a single block region, just collect the nested regions.
if (region->hasOneBlock()) {
@@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
for (Block *block : depth_first_ext(&region->front(), reachable))
(void)block /* Mark all reachable blocks */;
+ LDBG(2) << "Found " << reachable.size() << " reachable blocks out of "
+ << region->getBlocks().size() << " total blocks";
+
// Collect all of the dead blocks and push the live regions onto the
// worklist.
for (Block &block : llvm::make_early_inc_range(*region)) {
if (!reachable.count(&block)) {
+ LDBG() << "Erasing unreachable block: " << &block;
block.dropAllDefinedValueUses();
rewriter.eraseBlock(&block);
- erasedDeadBlocks = true;
+ ++erasedDeadBlocks;
continue;
}
@@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
}
}
- return success(erasedDeadBlocks);
+ LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks
+ << " dead blocks";
+
+ return success(erasedDeadBlocks > 0);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642..1382550 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,18 +13,40 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
namespace mlir {
+// Find all reachable blocks in the region and add them to the visitedBlocks
+// set.
+static void findReachableBlocks(Region &region,
+ DenseSet<Block *> &reachableBlocks) {
+ Block *entryBlock = &region.front();
+ reachableBlocks.insert(entryBlock);
+ // Traverse the CFG and add all reachable blocks to the blockList.
+ SmallVector<Block *> worklist({entryBlock});
+ while (!worklist.empty()) {
+ Block *block = worklist.pop_back_val();
+ Operation *terminator = &block->back();
+ for (Block *successor : terminator->getSuccessors()) {
+ if (reachableBlocks.contains(successor))
+ continue;
+ worklist.push_back(successor);
+ reachableBlocks.insert(successor);
+ }
+ }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+ // Iterator on all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op
+ // already to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
+ regionIt = region->begin();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
+ }
+ // Advance the iterator to the next reachable operation.
+ void advance() {
+ assert(regionIt != region->end());
+ hasVisitedRegions = false;
+ if (blockIt == regionIt->end()) {
+ ++regionIt;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << "Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+ // The region we're iterating over.
+ Region *region;
+ // The Block currently being iterated over.
+ Region::iterator regionIt;
+ // The Operation currently being iterated over.
+ Block::iterator blockIt;
+ // The set of blocks that are reachable in the current region.
+ DenseSet<Block *> reachableBlocks;
+ // Whether we've visited the nested regions of the current op already.
+ bool hasVisitedRegions = false;
+ };
+
+ // Worklist of regions to visit to drive the post-order traversal.
+ SmallVector<RegionReachableOpIterator> worklist;
+
+ LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
+ // Perform a post-order traversal of the regions, visiting each
+ // reachable operation.
for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
+ assert(worklist.empty());
+ if (region.empty())
+ continue;
+
+ // Prime the worklist with the entry block of this region.
+ worklist.push_back({&region});
+ while (!worklist.empty()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
+ for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion});
+ }
+ }
+ // If we're not at the back of the worklist, we've enqueued some
+ // nested region for processing. We'll come back to this op later
+ // (post-order)
+ if (&it != &worklist.back())
+ continue;
+
+ // Preemptively increment the iterator, in case the current op
+ // would be erased.
+ it.advance();
+
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- erasedListener.visitedOp = visitedOp;
+ erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ LDBG() << "\tOp matched and rewritten";
+ }
}
},
{op});
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index be71737..dcae3dd 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -194,6 +194,13 @@ class _OperationBase:
"""
Detaches the operation from its parent block.
"""
+
+ @property
+ def attached(self) -> bool:
+ """
+ Reports if the operation is attached to its parent block.
+ """
+
def erase(self) -> None: ...
@overload
diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
index 58d453d..4b82c78 100644
--- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
@@ -19,5 +19,6 @@ class ExecutionEngine:
def dump_to_object_file(self, file_name: str) -> None: ...
def raw_lookup(self, func_name: str) -> int: ...
def raw_register_runtime(self, name: str, callback: object) -> None: ...
+ def init() -> None: ...
@property
def _CAPIPtr(self) -> object: ...
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index a5efa05..10abd06 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -78,12 +78,12 @@ def equally_sized_accessor(
def get_default_loc_context(location=None):
"""
Returns a context in which the defaulted location is created. If the location
- is None, takes the current location from the stack, raises ValueError if there
- is no location on the stack.
+ is None, takes the current location from the stack.
"""
if location is None:
- # Location.current raises ValueError if there is no current location.
- return _cext.ir.Location.current.context
+ if _cext.ir.Location.current:
+ return _cext.ir.Location.current.context
+ return None
return location.context
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 1b359da..fd4a5a8 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -374,42 +374,6 @@ def quantized_matmul(
@linalg_structured_op
-def matmul_transpose_a(
- A=TensorDef(T1, S.K, S.N),
- B=TensorDef(T2, S.K, S.M),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
-
-
-@linalg_structured_op
-def matmul_transpose_b(
- A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.N, S.K),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
-
-
-@linalg_structured_op
def mmt4d(
lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
@@ -454,44 +418,6 @@ def batch_mmt4d(
@linalg_structured_op
-def batch_matmul_transpose_a(
- A=TensorDef(T1, Batch, S.K, S.M),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
-@linalg_structured_op
-def batch_matmul_transpose_b(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.N, S.K),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.n, D.k]
- )
-
-
-@linalg_structured_op
def quantized_batch_matmul(
A=TensorDef(T1, Batch, S.M, S.K),
B=TensorDef(T2, Batch, S.K, S.N),
@@ -513,25 +439,6 @@ def quantized_batch_matmul(
@linalg_structured_op
-def batch_reduce_matmul(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
-):
- """Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
-@linalg_structured_op
def matvec(
A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
):
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index a89a0f4..3748be7 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) {
%0:2 = func.call @private_1() : () -> (i32, i32)
return %0#0 : i32
}
+
+// -----
+
+// Test that we correctly set a liveness value for operations in dead block.
+// These won't be visited by the dataflow framework so the analysis need to
+// explicitly manage them.
+// CHECK-LABEL: test_tag: dead_block_cmpi:
+// CHECK-NEXT: operand #0: not live
+// CHECK-NEXT: operand #1: not live
+// CHECK-NEXT: result #0: not live
+func.func @dead_block() {
+ %false = arith.constant false
+ %zero = arith.constant 0 : i64
+ cf.cond_br %false, ^bb1, ^bb4
+ ^bb1:
+ %3 = arith.cmpi eq, %zero, %zero {tag = "dead_block_cmpi"} : i64
+ cf.br ^bb1
+ ^bb4:
+ return
+}
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index a7f9eb9..d451425 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -31,6 +31,13 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
MLIRCAPIExecutionEngine
MLIRCAPIRegisterEverything
)
+ _add_capi_test_executable(mlir-capi-global-constructors-test
+ global_constructors.c
+ LINK_LIBS PRIVATE
+ MLIRCAPIConversion
+ MLIRCAPIExecutionEngine
+ MLIRCAPIRegisterEverything
+)
endif()
_add_capi_test_executable(mlir-capi-ir-test
diff --git a/mlir/test/CAPI/global_constructors.c b/mlir/test/CAPI/global_constructors.c
new file mode 100644
index 0000000..bd2fe14
--- /dev/null
+++ b/mlir/test/CAPI/global_constructors.c
@@ -0,0 +1,113 @@
+//===- global_constructors.c - Test JIT with the global constructors ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}}
+/* RUN: mlir-capi-global-constructors-test 2>&1 | FileCheck %s
+ */
+/* REQUIRES: host-supports-jit
+ */
+
+#include "mlir-c/Conversion.h"
+#include "mlir-c/ExecutionEngine.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/RegisterEverything.h"
+
+#include <assert.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+static void registerAllUpstreamDialects(MlirContext ctx) {
+ MlirDialectRegistry registry = mlirDialectRegistryCreate();
+ mlirRegisterAllDialects(registry);
+ mlirContextAppendDialectRegistry(ctx, registry);
+ mlirDialectRegistryDestroy(registry);
+}
+
+void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ MlirOpPassManager opm = mlirPassManagerGetNestedUnder(
+ pm, mlirStringRefCreateFromCString("func.func"));
+ mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
+ mlirOpPassManagerAddOwnedPass(
+ opm, mlirCreateConversionArithToLLVMConversionPass());
+ MlirLogicalResult status =
+ mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
+ if (mlirLogicalResultIsFailure(status)) {
+ fprintf(stderr, "Unexpected failure running pass pipeline\n");
+ exit(2);
+ }
+ mlirPassManagerDestroy(pm);
+}
+
+// Helper variable to track callback invocations
+static int initCnt = 0;
+
+// Callback function that will be called during JIT initialization
+static void initCallback(void) { initCnt += 1; }
+
+// CHECK-LABEL: Running test 'testGlobalCtorJitCallback'
+void testGlobalCtorJitCallback(void) {
+ MlirContext ctx = mlirContextCreate();
+ registerAllUpstreamDialects(ctx);
+
+ // Create module with global constructor that calls our callback
+ MlirModule module = mlirModuleCreateParse(
+ ctx, mlirStringRefCreateFromCString(
+ // clang-format off
+"module { \n"
+" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n"
+" llvm.func @ctor() { \n"
+" func.call @init_callback() : () -> () \n"
+" llvm.return \n"
+" } \n"
+" func.func private @init_callback() attributes { llvm.emit_c_interface } \n"
+"} \n"
+ // clang-format on
+ ));
+
+ lowerModuleToLLVM(ctx, module);
+ mlirRegisterAllLLVMTranslations(ctx);
+
+ // Create execution engine with initialization disabled
+ MlirExecutionEngine jit = mlirExecutionEngineCreate(
+ module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
+ /*enableObjectDump=*/false);
+
+ if (mlirExecutionEngineIsNull(jit)) {
+ fprintf(stderr, "Execution engine creation failed");
+ exit(2);
+ }
+
+ // Register callback symbol before initialization
+ mlirExecutionEngineRegisterSymbol(
+ jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"),
+ (void *)(uintptr_t)initCallback);
+
+ mlirExecutionEngineInitialize(jit);
+
+ // CHECK: Init count: 1
+ printf("Init count: %d\n", initCnt);
+
+ mlirExecutionEngineDestroy(jit);
+ mlirModuleDestroy(module);
+ mlirContextDestroy(ctx);
+}
+
+int main(void) {
+
+#define _STRINGIFY(x) #x
+#define STRINGIFY(x) _STRINGIFY(x)
+#define TEST(test) \
+ printf("Running test '" STRINGIFY(test) "'\n"); \
+ test();
+ TEST(testGlobalCtorJitCallback);
+ return 0;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index a4a942d..8ddc620 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -123,7 +123,7 @@ set(MLIR_TEST_DEPENDS
tblgen-to-irdl
)
if(NOT MLIR_STANDALONE_BUILD)
- list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file)
+ list(APPEND MLIR_TEST_DEPENDS FileCheck count not split-file yaml2obj)
endif()
set(MLIR_TEST_DEPENDS ${MLIR_TEST_DEPENDS}
@@ -141,6 +141,7 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH})
llc
mlir_async_runtime
mlir-capi-execution-engine-test
+ mlir-capi-global-constructors-test
mlir_c_runner_utils
mlir_runner_utils
mlir_float16_utils
@@ -156,7 +157,10 @@ if(MLIR_ENABLE_CUDA_RUNNER)
endif()
if(MLIR_ENABLE_EXECUTION_ENGINE)
- list(APPEND MLIR_TEST_DEPENDS mlir-capi-execution-engine-test)
+ list(APPEND MLIR_TEST_DEPENDS
+ mlir-capi-execution-engine-test
+ mlir-capi-global-constructors-test
+ )
endif()
if(MLIR_ENABLE_ROCM_RUNNER)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
new file mode 100644
index 0000000..aae2b1d
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
@@ -0,0 +1,163 @@
+// RUN: mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx950 --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @test_permlane16_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane_swap %arg0 16 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane16_i32_optional_attr
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane32_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane_swap %arg0 32 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane16_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.permlane_swap %arg0 16 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_permlane32_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.permlane_swap %arg0 32 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_permlane16_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.permlane_swap %arg0 16 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_permlane32_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.permlane_swap %arg0 32 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_permlane16_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
+ %0 = amdgpu.permlane_swap %arg0 16 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_permlane32_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
+ %0 = amdgpu.permlane_swap %arg0 32 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_permlane16_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.permlane_swap %arg0 16 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: func @test_permlane32_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.permlane_swap %arg0 32 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 1382f3c..319dfc31 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -153,7 +153,7 @@ func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
- // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 {
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression %[[C1]], %[[C2]], %[[CmpNoExcess]], %[[Zero]] : (ui32, ui32, i1, ui32) -> ui32 {
// CHECK-NEXT: %[[SHL:[^ ]*]] = bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK-NEXT: %[[Ternary:[^ ]*]] = conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
// CHECK-NEXT: yield %[[Ternary]] : ui32
@@ -173,7 +173,7 @@ func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
- // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 {
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression %[[C1]], %[[C2]], %[[CmpNoExcess]], %[[Zero]] : (ui32, ui32, i1, ui32) -> ui32 {
// CHECK-NEXT: %[[SHR:[^ ]*]] = bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK-NEXT: %[[Ternary:[^ ]*]] = conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
// CHECK-NEXT: yield %[[Ternary]] : ui32
@@ -185,7 +185,7 @@ func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
- // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32 {
+ // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression %[[ARG0]], %[[SC2]], %[[SCmpNoExcess]], %[[SZero]] : (i32, ui32, i1, i32) -> i32 {
// CHECK-NEXT: %[[SHRSI:[^ ]*]] = bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
// CHECK-NEXT: %[[STernary:[^ ]*]] = conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
// CHECK-NEXT: yield %[[STernary]] : i32
@@ -210,7 +210,7 @@ func.func @arith_shift_left_index(%amount: i32) {
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
- // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t {
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression %[[C1]], %[[AmountIdx]], %[[CmpNoExcess]], %[[Zero]] : (!emitc.size_t, !emitc.size_t, i1, !emitc.size_t) -> !emitc.size_t {
// CHECK-NEXT: %[[SHL:[^ ]*]] = bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-NEXT: %[[Ternary:[^ ]*]] = conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
// CHECK-NEXT: yield %[[Ternary]] : !emitc.size_t
@@ -235,7 +235,7 @@ func.func @arith_shift_right_index(%amount: i32) {
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
- // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t {
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression %[[C1]], %[[AmountIdx]], %[[CmpNoExcess]], %[[Zero]] : (!emitc.size_t, !emitc.size_t, i1, !emitc.size_t) -> !emitc.size_t {
// CHECK-NEXT: %[[SHR:[^ ]*]] = bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-NEXT: %[[Ternary:[^ ]*]] = conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
// CHECK-NEXT: yield %[[Ternary]] : !emitc.size_t
@@ -248,7 +248,7 @@ func.func @arith_shift_right_index(%amount: i32) {
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ptrdiff_t
- // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ptrdiff_t {
+ // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression %[[SC1]], %[[AmountIdx]], %[[SCmpNoExcess]], %[[SZero]] : (!emitc.ptrdiff_t, !emitc.size_t, i1, !emitc.ptrdiff_t) -> !emitc.ptrdiff_t {
// CHECK-NEXT: %[[SHRSI:[^ ]*]] = bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ptrdiff_t, !emitc.size_t) -> !emitc.ptrdiff_t
// CHECK-NEXT: %[[STernary:[^ ]*]] = conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ptrdiff_t
// CHECK-NEXT: yield %[[STernary]] : !emitc.ptrdiff_t
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 83bdbe1..ba12ff2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -3,6 +3,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// CHECK-LABEL: @vector_ops
func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>, %arg3: vector<4xi64>) -> vector<4xf32> {
@@ -373,12 +374,11 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
// CHECK-LABEL: @integer_cast_0d_vector
func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
-// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-// CHECK-NEXT: = llvm.sext %[[ARG0]] : vector<1xi3> to vector<1xi6>
+// CHECK: = llvm.sext %{{.*}}: vector<1xi3> to vector<1xi6>
%0 = arith.extsi %arg0 : vector<i3> to vector<i6>
-// CHECK-NEXT: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6>
+// CHECK-NEXT: = llvm.zext %{{.*}} : vector<1xi3> to vector<1xi6>
%1 = arith.extui %arg0 : vector<i3> to vector<i6>
-// CHECK-NEXT: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2>
+// CHECK-NEXT: = llvm.trunc %{{.*}} : vector<1xi3> to vector<1xi2>
%2 = arith.trunci %arg0 : vector<i3> to vector<i2>
return
}
diff --git a/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
new file mode 100644
index 0000000..e3a0c82
--- /dev/null
+++ b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @arith_select(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
+// CHECK: %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[i2]]
+func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
+ %0 = arith.select %arg0, %arg1, %arg2 : i17
+ return %0 : i17
+}
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index ad1b665..4d2c12a 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -3,6 +3,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=complex" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=complex allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// CHECK-LABEL: func @complex_create
// CHECK-SAME: (%[[REAL0:.*]]: f32, %[[IMAG0:.*]]: f32)
@@ -23,9 +24,9 @@ func.func @complex_constant() -> complex<f64> {
// CHECK-LABEL: func @complex_extract
// CHECK-SAME: (%[[CPLX:.*]]: complex<f32>)
-// CHECK-NEXT: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
-// CHECK-NEXT: %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)>
-// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)>
+// CHECK: builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(f32, f32)>
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(f32, f32)>
func.func @complex_extract(%cplx: complex<f32>) {
%real1 = complex.re %cplx : complex<f32>
%imag1 = complex.im %cplx : complex<f32>
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index ae59f28..080ba4f 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -1,19 +1,13 @@
-// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
-// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
-// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
-// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
-// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
-// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
-// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
@@ -33,16 +27,6 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
return %rf, %rd : f32, f64
}
-//CHECK-LABEL: @angle_caller
-func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
- // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
- %af = complex.angle %f : complex<f32>
- // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
- %ad = complex.angle %d : complex<f64>
- // CHECK: return %[[AF]], %[[AD]]
- return %af, %ad : f32, f64
-}
-
//CHECK-LABEL: @cos_caller
func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
@@ -73,24 +57,15 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
return %lf, %ld : complex<f32>, complex<f64>
}
-//CHECK-LABEL: @conj_caller
-func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
- // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
- %cf2 = complex.conj %f : complex<f32>
- // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
- %cd2 = complex.conj %d : complex<f64>
- // CHECK: return %[[CF]], %[[CD]]
- return %cf2, %cd2 : complex<f32>, complex<f64>
-}
-
//CHECK-LABEL: @pow_caller
-func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
- // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
- %pf = complex.pow %f, %f : complex<f32>
- // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
- %pd = complex.pow %d, %d : complex<f64>
- // CHECK: return %[[PF]], %[[PD]]
- return %pf, %pd : complex<f32>, complex<f64>
+//CHECK: (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
+func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
+ // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
+ // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
+ // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
+ // CHECK: return %[[EXP]]
+ %r = complex.pow %z, %w : complex<f32>
+ return %r : complex<f32>
}
//CHECK-LABEL: @sin_caller
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
index 3ec8f1f..18d0526 100644
--- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
+++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
@@ -3,6 +3,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=cf" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=cf allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
func.func @main() {
%a = arith.constant 0 : i1
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index d68ba44..c85f433 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -83,20 +83,16 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
// CHECK-LABEL: @transpose
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
- // CHECK: %[[UB:.*]] = ub.poison : vector<2xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
- // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[UB]] [0] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
- // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
+ // CHECK: %[[FROM_ELEMENTS0:.*]] = vector.from_elements %[[EXTRACT0]], %[[EXTRACT1]] : vector<2xi32>
// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
- // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[UB]] [0] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
- // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
+ // CHECK: %[[FROM_ELEMENTS1:.*]] = vector.from_elements %[[EXTRACT2]], %[[EXTRACT3]] : vector<2xi32>
// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
- // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[UB]] [0] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
- // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
- // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
+ // CHECK: %[[FROM_ELEMENTS2:.*]] = vector.from_elements %[[EXTRACT4]], %[[EXTRACT5]] : vector<2xi32>
+ // CHECK: return %[[FROM_ELEMENTS0]], %[[FROM_ELEMENTS1]], %[[FROM_ELEMENTS2]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
return %0 : vector<3x2xi32>
}
diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
index 2113557..94dfcea 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
@@ -9,6 +9,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith,cf,func,math" %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith,cf,func,math allow-pattern-rollback=0" %s | FileCheck %s
// CHECK-LABEL: func @empty() {
// CHECK-NEXT: llvm.return
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
index ed7fa65..0016db5 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true}))" | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true allow-pattern-rollback=0}))" | FileCheck %s
// CHECK-LABEL: gpu.module @nvvm_module
gpu.module @nvvm_module [#nvvm.target] {
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 2b6adff..6dd03b1 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950 allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950 index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
// CHECK-LABEL: @test_module
// CHECK-SAME: llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
@@ -54,8 +54,8 @@ gpu.module @test_module {
// CHECK: = llvm.sext %{{.*}} : i32 to i64
%gDimZ = gpu.grid_dim z
- // CHECK: = rocdl.mbcnt.lo %{{.*}}, %{{.*}} : (i32, i32) -> i32
- // CHECK: = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+ // CHECK: = rocdl.mbcnt.lo %{{.*}}, %{{.*}} {res_attrs = [{llvm.noundef, llvm.range = #llvm.constant_range<i32, 0, 32>}]} : (i32, i32) -> i32
+ // CHECK: = rocdl.mbcnt.hi %{{.*}}, %{{.*}} {res_attrs = [{llvm.noundef, llvm.range = #llvm.constant_range<i32, 0, 64>}]} : (i32, i32) -> i32
// CHECK: = llvm.sext %{{.*}} : i32 to i64
%laneId = gpu.lane_id
@@ -701,7 +701,7 @@ gpu.module @test_module {
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
- %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
+ %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
// *** UP mode shuffle ***
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
@@ -734,14 +734,40 @@ gpu.module @test_module {
func.return %shfl, %shfli, %shflu, %shfld : f32, f32, f32, f32
}
+ // CHECK-LABEL: func @gpu_shuffle_promote()
+ func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
+ // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ %arg0 = arith.constant 1.0 : f32
+ %arg1 = arith.constant 4 : i32
+ %arg2 = arith.constant 16 : i32
+ %arg3 = arith.constant 32 : i32
+ %arg4 = arith.constant 64 : i32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+ %shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
+ func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
+ }
+
// CHECK-LABEL: func @gpu_shuffle_vec
// CHECK-SAME: (%[[ARG:.*]]: vector<4xf16>, %{{.*}}: i32, %{{.*}}: i32)
func.func @gpu_shuffle_vec(%arg0: vector<4xf16>, %arg1: i32, %arg2: i32) -> vector<4xf16> {
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[ELEM0:.*]] = llvm.extractelement %13[%[[IDX0]] : i32] : vector<2xi32>
+ // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[IDX0]] : i32] : vector<2xi32>
// CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ELEM1:.*]] = llvm.extractelement %13[%[[IDX1]] : i32] : vector<2xi32>
+ // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[IDX1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM0]] : (i32, i32) -> i32
// CHECK: %[[PERM1:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM1]] : (i32, i32) -> i32
// CHECK: %[[V0:.*]] = llvm.mlir.poison : vector<2xi32>
@@ -776,3 +802,19 @@ gpu.module @test_module {
func.return %bDimX : index
}
}
+
+// -----
+
+gpu.module @test_module {
+// CHECK-LABEL: func @broadcast
+// CHECK-SAME: (%[[ARG:.*]]: i64, %[[IDX:.*]]: i32)
+func.func @broadcast(%arg0 : index, %arg1 : i32) -> (index, index, index) {
+// CHECK: %{{.*}} = rocdl.readfirstlane %[[ARG]] : i64
+// CHECK: %{{.*}} = rocdl.readfirstlane %[[ARG]] : i64
+// CHECK: %{{.*}} = rocdl.readlane %[[ARG]], %[[IDX]] : (i64, i32) -> i64
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : index
+ %1 = gpu.subgroup_broadcast %arg0, any_lane : index
+ %2 = gpu.subgroup_broadcast %arg0, specific_lane %arg1 : index
+ func.return %0, %1, %2 : index, index, index
+}
+}
diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
index 26abb3b..007929e 100644
--- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
+++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
@@ -5,6 +5,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=index" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=index allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// CHECK-LABEL: @trivial_ops
func.func @trivial_ops(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 9290408..f454122 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -3,6 +3,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=math" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=math allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// CHECK-LABEL: @ops
func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 08354db..26b5456 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -79,21 +79,17 @@ func.func @absf_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-LABEL: func @absf_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @fabsf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @fabsf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @fabs(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @fabs(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @absf_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.absf %float : vector<2xf32>
@@ -116,21 +112,17 @@ func.func @acos_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-LABEL: func @acos_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @acosf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @acosf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @acos(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @acos(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @acos_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.acos %float : vector<2xf32>
@@ -153,21 +145,17 @@ func.func @acosh_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-LABEL: func @acosh_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @acoshf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @acoshf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @acosh(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @acosh(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @acosh_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.acosh %float : vector<2xf32>
@@ -190,21 +178,17 @@ func.func @asin_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-LABEL: func @asin_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @asinf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @asinf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @asin(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @asin(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @asin_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.asin %float : vector<2xf32>
@@ -227,21 +211,17 @@ func.func @asinh_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-LABEL: func @asinh_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @asinhf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @asinhf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @asinh(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @asinh(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @asinh_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.asinh %float : vector<2xf32>
@@ -274,21 +254,17 @@ func.func @atan_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) ->
// CHECK-LABEL: func @atan_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @atanf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @atanf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @atan(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @atan(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @atan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.atan %float : vector<2xf32>
@@ -321,21 +297,17 @@ func.func @atanh_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) ->
// CHECK-LABEL: func @atanh_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @atanhf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @atanhf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @atanh(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @atanh(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @atanh_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.atanh %float : vector<2xf32>
@@ -419,23 +391,19 @@ func.func @erf_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
func.func @erf_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
- // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
- // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @erff(%[[IN0_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @erff(%[[IN1_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+ // CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
%float_result = math.erf %float : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @erf(%[[IN0_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @erf(%[[IN1_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+ // CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
%double_result = math.erf %double : vector<2xf64>
- // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+ // CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
@@ -459,21 +427,17 @@ func.func @exp_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
// CHECK-LABEL: func @exp_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @expf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @expf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @exp(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @exp(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @exp2_caller
@@ -496,21 +460,17 @@ func.func @exp2_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
// CHECK-LABEL: func @exp2_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @exp2f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @exp2f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @exp2(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @exp2(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @log_caller
@@ -533,21 +493,17 @@ func.func @log_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
// CHECK-LABEL: func @log_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @logf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @logf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @log(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @log(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @log2_caller
@@ -570,21 +526,17 @@ func.func @log2_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
// CHECK-LABEL: func @log2_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @log2f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @log2f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @log2(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @log2(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @log10_caller
@@ -607,21 +559,17 @@ func.func @log10_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
// CHECK-LABEL: func @log10_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @log10f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @log10f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @log10(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @log10(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @expm1_caller
@@ -644,21 +592,17 @@ func.func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
// CHECK-LABEL: func @expm1_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32>) {
@@ -667,20 +611,16 @@ func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32
}
// CHECK-LABEL: func @expm1_multidim_vec_caller(
// CHECK-SAME: %[[VAL:.*]]: vector<2x2xf32>
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[IN0_0_F32:.*]] = vector.extract %[[VAL]][0, 0] : f32 from vector<2x2xf32>
// CHECK: %[[OUT0_0_F32:.*]] = call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_1:.*]] = vector.insert %[[OUT0_0_F32]], %[[CVF]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[IN0_1_F32:.*]] = vector.extract %[[VAL]][0, 1] : f32 from vector<2x2xf32>
// CHECK: %[[OUT0_1_F32:.*]] = call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_2:.*]] = vector.insert %[[OUT0_1_F32]], %[[VAL_1]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[IN1_0_F32:.*]] = vector.extract %[[VAL]][1, 0] : f32 from vector<2x2xf32>
// CHECK: %[[OUT1_0_F32:.*]] = call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_3:.*]] = vector.insert %[[OUT1_0_F32]], %[[VAL_2]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[IN1_1_F32:.*]] = vector.extract %[[VAL]][1, 1] : f32 from vector<2x2xf32>
// CHECK: %[[OUT1_1_F32:.*]] = call @expm1f(%[[IN1_1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
-// CHECK: return %[[VAL_4]] : vector<2x2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_0_F32]], %[[OUT0_1_F32]], %[[OUT1_0_F32]], %[[OUT1_1_F32]] : vector<2x2xf32>
+// CHECK: return %[[RES_F32]] : vector<2x2xf32>
// CHECK: }
// CHECK-LABEL: func @fma_caller(
@@ -704,29 +644,25 @@ func.func @fma_vec_caller(%float_a: vector<2xf32>, %float_b: vector<2xf32>, %flo
// CHECK-SAME: %[[VAL_0A:.*]]: vector<2xf32>, %[[VAL_0B:.*]]: vector<2xf32>, %[[VAL_0C:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1A:.*]]: vector<2xf64>, %[[VAL_1B:.*]]: vector<2xf64>, %[[VAL_1C:.*]]: vector<2xf64>
// CHECK-SAME: ) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32A:.*]] = vector.extract %[[VAL_0A]][0] : f32 from vector<2xf32>
// CHECK: %[[IN0_F32B:.*]] = vector.extract %[[VAL_0B]][0] : f32 from vector<2xf32>
// CHECK: %[[IN0_F32C:.*]] = vector.extract %[[VAL_0C]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @fmaf(%[[IN0_F32A]], %[[IN0_F32B]], %[[IN0_F32C]]) : (f32, f32, f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32A:.*]] = vector.extract %[[VAL_0A]][1] : f32 from vector<2xf32>
// CHECK: %[[IN1_F32B:.*]] = vector.extract %[[VAL_0B]][1] : f32 from vector<2xf32>
// CHECK: %[[IN1_F32C:.*]] = vector.extract %[[VAL_0C]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @fmaf(%[[IN1_F32A]], %[[IN1_F32B]], %[[IN1_F32C]]) : (f32, f32, f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64A:.*]] = vector.extract %[[VAL_1A]][0] : f64 from vector<2xf64>
// CHECK: %[[IN0_F64B:.*]] = vector.extract %[[VAL_1B]][0] : f64 from vector<2xf64>
// CHECK: %[[IN0_F64C:.*]] = vector.extract %[[VAL_1C]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @fma(%[[IN0_F64A]], %[[IN0_F64B]], %[[IN0_F64C]]) : (f64, f64, f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64A:.*]] = vector.extract %[[VAL_1A]][1] : f64 from vector<2xf64>
// CHECK: %[[IN1_F64B:.*]] = vector.extract %[[VAL_1B]][1] : f64 from vector<2xf64>
// CHECK: %[[IN1_F64C:.*]] = vector.extract %[[VAL_1C]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @fma(%[[IN1_F64A]], %[[IN1_F64B]], %[[IN1_F64C]]) : (f64, f64, f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @round_caller
@@ -814,23 +750,19 @@ func.func @sin_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
- // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
- // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @roundf(%[[IN0_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @roundf(%[[IN1_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+ // CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
%float_result = math.round %float : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @round(%[[IN0_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @round(%[[IN1_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+ // CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
%double_result = math.round %double : vector<2xf64>
- // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+ // CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
@@ -838,23 +770,19 @@ func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
- // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
- // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @roundevenf(%[[IN0_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @roundevenf(%[[IN1_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+ // CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
%float_result = math.roundeven %float : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @roundeven(%[[IN0_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @roundeven(%[[IN1_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+ // CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
%double_result = math.roundeven %double : vector<2xf64>
- // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+ // CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
@@ -862,23 +790,19 @@ func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
func.func @trunc_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
- // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
- // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @truncf(%[[IN0_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @truncf(%[[IN1_F32]]) : (f32) -> f32
- // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+ // CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
%float_result = math.trunc %float : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @trunc(%[[IN0_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @trunc(%[[IN1_F64]]) : (f64) -> f64
- // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+ // CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
%double_result = math.trunc %double : vector<2xf64>
- // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+ // CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
@@ -907,21 +831,17 @@ func.func @tan_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) -> (
// CHECK-LABEL: func @tan_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @tanf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @tanf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @tan(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @tan(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.tan %float : vector<2xf32>
@@ -985,21 +905,17 @@ func.func @sqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
// CHECK-LABEL: func @sqrt_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @sqrtf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @sqrtf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @sqrt(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @sqrt(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @rsqrt_caller
@@ -1022,21 +938,17 @@ func.func @rsqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
// CHECK-LABEL: func @rsqrt_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @rsqrtf(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @rsqrtf(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @rsqrt(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @rsqrt(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
// CHECK-LABEL: func @powf_caller(
@@ -1060,23 +972,19 @@ func.func @powf_vec_caller(%float_a: vector<2xf32>, %float_b: vector<2xf32>, %do
// CHECK-SAME: %[[VAL_0A:.*]]: vector<2xf32>, %[[VAL_0B:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1A:.*]]: vector<2xf64>, %[[VAL_1B:.*]]: vector<2xf64>
// CHECK-SAME: ) -> (vector<2xf32>, vector<2xf64>) {
-// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[IN0_F32A:.*]] = vector.extract %[[VAL_0A]][0] : f32 from vector<2xf32>
// CHECK: %[[IN0_F32B:.*]] = vector.extract %[[VAL_0B]][0] : f32 from vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @powf(%[[IN0_F32A]], %[[IN0_F32B]]) : (f32, f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
// CHECK: %[[IN1_F32A:.*]] = vector.extract %[[VAL_0A]][1] : f32 from vector<2xf32>
// CHECK: %[[IN1_F32B:.*]] = vector.extract %[[VAL_0B]][1] : f32 from vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @powf(%[[IN1_F32A]], %[[IN1_F32B]]) : (f32, f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[RES_F32:.*]] = vector.from_elements %[[OUT0_F32]], %[[OUT1_F32]] : vector<2xf32>
// CHECK: %[[IN0_F64A:.*]] = vector.extract %[[VAL_1A]][0] : f64 from vector<2xf64>
// CHECK: %[[IN0_F64B:.*]] = vector.extract %[[VAL_1B]][0] : f64 from vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @pow(%[[IN0_F64A]], %[[IN0_F64B]]) : (f64, f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
// CHECK: %[[IN1_F64A:.*]] = vector.extract %[[VAL_1A]][1] : f64 from vector<2xf64>
// CHECK: %[[IN1_F64B:.*]] = vector.extract %[[VAL_1B]][1] : f64 from vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @pow(%[[IN1_F64A]], %[[IN1_F64B]]) : (f64, f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
-// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK: %[[RES_F64:.*]] = vector.from_elements %[[OUT0_F64]], %[[OUT1_F64]] : vector<2xf64>
+// CHECK: return %[[RES_F32]], %[[RES_F64]] : vector<2xf32>, vector<2xf64>
// CHECK: }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
new file mode 100644
index 0000000..c1627a0
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+
+func.func @alloc_copy(%arg0: memref<999xi32>) {
+ %alloc = memref.alloc() : memref<999xi32>
+ memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
+ %alloc_1 = memref.alloc() : memref<999xi32>
+ memref.copy %arg0, %alloc_1 : memref<999xi32> to memref<999xi32>
+ return
+}
+
+// CHECK: module {
+// NOCPP: emitc.include <"stdlib.h">
+// NOCPP-NEXT: emitc.include <"string.h">
+
+// CPP: emitc.include <"cstdlib">
+// CPP-NEXT: emitc.include <"cstring">
+
+// CHECK-LABEL: alloc_copy
+// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK-NEXT: return
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
new file mode 100644
index 0000000..d151d1b
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+
+func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
+ memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
+ return
+}
+
+// CHECK: module {
+// NOCPP: emitc.include <"string.h">
+// CPP: emitc.include <"cstring">
+
+// CHECK-LABEL: copying
+// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK-NEXT:}
+
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index fda0197..b6eccfc 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -1,13 +1,5 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
-func.func @memref_op(%arg0 : memref<2x4xf32>) {
- // expected-error@+1 {{failed to legalize operation 'memref.copy'}}
- memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
- return
-}
-
-// -----
-
func.func @alloca_with_dynamic_shape() {
%0 = index.constant 1
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index ad9d649..45b1a1f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -464,7 +464,9 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
+ memref.atomic_rmw xori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
+ // CHECK-INTERFACE-COUNT-14: llvm.atomicrmw
return
}
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index 0288aa1..6c6756f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
// Test the argument materializer for ranked MemRef types.
// CHECK-LABEL: func @construct_ranked_memref_descriptor(
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-COUNT-7: llvm.insertvalue
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
-func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
return
@@ -21,7 +22,7 @@ func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr
// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
// CHECK: "test.legal_op"(%[[cast]])
-func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
return
@@ -32,10 +33,10 @@ func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
// Test the argument materializer for unranked MemRef types.
// CHECK-LABEL: func @construct_unranked_memref_descriptor(
-// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK: llvm.mlir.poison : !llvm.struct<(i64, ptr)>
// CHECK-COUNT-2: llvm.insertvalue
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
-func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
"test.legal_op"(%0) : (memref<*xf32>) -> ()
return
@@ -50,8 +51,107 @@ func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
// CHECK: "test.legal_op"(%[[cast]])
-func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
"test.legal_op"(%0) : (memref<*xf32>) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @simple_func_conversion(
+// CHECK-SAME: %[[arg0:.*]]: i64) -> i64
+// CHECK: llvm.return %[[arg0]] : i64
+func.func @simple_func_conversion(%arg0: i64) -> i64 {
+ return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @one_to_n_argument_conversion(
+// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]] : i18, i18 to i17
+// CHECK: "test.legal_op"(%[[cast]]) : (i17) -> ()
+func.func @one_to_n_argument_conversion(%arg0: i17) {
+ "test.legal_op"(%arg0) : (i17) -> ()
+ return
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK: llvm.call @one_to_n_argument_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> ()
+func.func @caller(%arg0: i17) {
+ func.call @one_to_n_argument_conversion(%arg0) : (i17) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @one_to_n_return_conversion(
+// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[p3]]
+func.func @one_to_n_return_conversion(%arg0: i17) -> i17 {
+ return %arg0 : i17
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK: %[[res:.*]] = llvm.call @one_to_n_return_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[i2]]
+func.func @caller(%arg0: i17) -> (i17) {
+ %res = func.call @one_to_n_return_conversion(%arg0) : (i17) -> (i17)
+ return %res : i17
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @multi_return(
+// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18, %[[arg2:.*]]: i1) -> !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p4:.*]] = llvm.insertvalue %[[arg2]], %[[p3]][2] : !llvm.struct<(i18, i18, i1)>
+// CHECK: llvm.return %[[p4]]
+func.func @multi_return(%arg0: i17, %arg1: i1) -> (i17, i1) {
+ return %arg0, %arg1 : i17, i1
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK: %[[res:.*]] = llvm.call @multi_return(%[[arg1]], %[[arg2]], %[[arg0]]) : (i18, i18, i1) -> !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[e2:.*]] = llvm.extractvalue %[[res]][2] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1, i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0]
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1]
+// CHECK: %[[i3:.*]] = llvm.insertvalue %[[e2]], %[[i2]][2]
+// CHECK: %[[i4:.*]] = llvm.insertvalue %[[e0]], %[[i3]][3]
+// CHECK: %[[i5:.*]] = llvm.insertvalue %[[e1]], %[[i4]][4]
+// CHECK: llvm.return %[[i5]]
+func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
+ %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
+ return %res#0, %res#1, %res#0 : i17, i1, i17
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @branch(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK: llvm.br ^[[bb1:.*]](%[[arg1]], %[[arg2]], %[[arg0]] : i18, i18, i1)
+// CHECK: ^[[bb1]](%[[arg3:.*]]: i18, %[[arg4:.*]]: i18, %[[arg5:.*]]: i1):
+// CHECK: llvm.cond_br %[[arg5]], ^[[bb1]](%[[arg1]], %[[arg2]], %[[arg5]] : i18, i18, i1), ^[[bb2:.*]](%[[arg3]], %[[arg4]] : i18, i18)
+// CHECK: ^bb2(%{{.*}}: i18, %{{.*}}: i18):
+// CHECK: llvm.return
+func.func @branch(%arg0: i1, %arg1: i17) {
+ cf.br ^bb1(%arg1, %arg0: i17, i1)
+^bb1(%arg2: i17, %arg3: i1):
+ cf.cond_br %arg3, ^bb1(%arg1, %arg3 : i17, i1), ^bb2(%arg2 : i17)
+^bb2(%arg4: i17):
+ return
+}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 2a7be0b..e6321e9 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -85,6 +85,28 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
return %0: i1
}
+// CHECK-LABEL: func @load_aligned
+// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+ // CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned", 32] : i8
+ %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+ return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_nontemporal
+func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+ // CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned|Nontemporal", 32] : i8
+ %0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
+ return %0: i1
+}
+
+// CHECK-LABEL: func @load_aligned_psb
+func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
+ // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" {{.*}} ["Aligned", 32] : i8
+ %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>
+ return %0: i1
+}
+
// CHECK-LABEL: func @store_i1
// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
// CHECK-SAME: %[[IDX:.+]]: index
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8d4f947..0c500e1 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK-LABEL: @ldmatrix_x4
func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+ // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
// CHECK: llvm.extractvalue
// CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
// CHECK-LABEL: @ldmatrix_x1
func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+ // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
// CHECK: llvm.bitcast
// CHECK: llvm.insertvalue
@@ -817,9 +817,9 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none> to !llvm.ptr
- // CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
+ // CHECK: nvvm.prefetch tensormap, %[[S0]] : !llvm.ptr
nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
- // CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
+ // CHECK: nvvm.prefetch tensormap, %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
nvgpu.tma.prefetch.descriptor %tensorMap1d, predicate = %p: !tensorMap1d
func.return
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 580b09d..bf80d9a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -3,6 +3,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// CHECK-LABEL: @init_mbarrier
llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
@@ -213,47 +214,36 @@ func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>,
// CHECK-LABEL: @tma_store_1d
func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>
return
}
// CHECK-LABEL: @tma_store_2d
func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>
return
}
// CHECK-LABEL: @tma_store_3d
func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>
return
}
// CHECK-LABEL: @tma_store_4d
func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>
return
}
// CHECK-LABEL: @tma_store_5d
func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
- // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
-
// CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>
return
}
@@ -582,10 +572,10 @@ func.func @elect_one_leader_sync() {
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
- nvvm.prefetch.tensormap %desc : !llvm.ptr
+ //CHECK: nvvm.prefetch tensormap, %{{.*}}
+ nvvm.prefetch tensormap, %desc : !llvm.ptr
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
- nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
+ nvvm.prefetch tensormap, %desc, predicate = %pred : !llvm.ptr, i1
llvm.return
}
@@ -666,18 +656,127 @@ llvm.func @init_mbarrier(
%count : i32,
%pred : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
- nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+ nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32)
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
- nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
+ nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32), predicate = %pred
llvm.return
}
// -----
llvm.func @ex2(%input : f32, %pred : i1) {
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
- %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+ %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
- %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32
+ %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32
+ llvm.return
+}
+
+// CHECK-LABEL: @multi_return(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
+llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
+ // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
+ // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
+ // CHECK: llvm.return %[[S4]] : i32
+ %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32) -> i32,i32
+ %r3 = llvm.add %r1, %r2 : i32
+ llvm.return %r3 : i32
+}
+
+// CHECK-LABEL: @inline_ptx_multi_rw(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
+llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}",
+// CHECK-SAME: "=f,=f,r,r,0,1"
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]]
+// CHECK-SAME: : (f32, f32, i32, i32) -> !llvm.struct<(f32, f32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)>
+// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: llvm.return %[[S3]] : f32
+ nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32)
+ rw (%rw_c, %rw_d: f32,f32)
+ %r4 = llvm.fadd %rw_c, %rw_d : f32
+ llvm.return %r4 : f32
+}
+
+// CHECK-LABEL: @inline_ptx_multi_rw_r(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
+llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $4, $5; selp.s32 $0, $4,$5, p; selp.s32 $1, $4,$5, p; selp.s32 $2, $4,$5, p; selp.s32 $3, $4,$5, p;}",
+// CHECK-SAME: "=f,=f,=r,=r,r,r,0,1"
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] :
+// CHECK-SAME: (f32, f32, i32, i32) -> !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S0]][2] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[S0]][3] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S5:.+]] = llvm.add %[[S3]], %[[S4]] : i32
+// CHECK: %[[S6:.+]] = llvm.sitofp %[[S5]] : i32 to f32
+// CHECK: %[[S7:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: %[[S8:.+]] = llvm.fadd %[[S6]], %[[S2]] : f32
+// CHECK: llvm.return %[[S8]] : f32
+
+ %wo0, %wo1 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32)
+ rw (%rw_c, %rw_d: f32,f32) -> i32,i32
+ %r3 = llvm.add %wo0, %wo1 : i32
+ %r3f = llvm.sitofp %r3 : i32 to f32
+ %r4 = llvm.fadd %rw_c, %rw_d : f32
+ %r5 = llvm.fadd %r3f, %rw_d : f32
+ llvm.return %r5 : f32
+}
+
+
+// -----
+
+// CHECK-LABEL: @nvvm_pmevent
+llvm.func @nvvm_pmevent() {
+ // CHECK: %[[S0:.+]] = llvm.mlir.constant(10 : i32) : i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "pmevent $0;", "n" %[[S0]] : (i32) -> ()
+
+ nvvm.pmevent id = 10
+ // CHECK: %[[S1:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "pmevent $0;", "n" %[[S1]] : (i32) -> ()
+ nvvm.pmevent id = 4
llvm.return
}
+
+// -----
+
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>, %mask : i32, %zero: i32) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,r,r" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+ %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
+ ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
+ -> i32
+ llvm.return
+}
+
+llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};"
+ ro(%a, %b : f32, f32)
+ -> vector<2xbf16>
+ llvm.return
+}
+
+llvm.func @inline_ptx_cvt_rn_e4m3x2_f16x2(%a : i16) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1", "=f,h" %{{.*}} : (i16) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0}"
+ ro(%a : i16)
+ -> vector<2xbf16>
+ llvm.return
+}
+
+llvm.func @cvt_i8_bf16(%a : i8) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $0;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $1, r;\0A\09", "=h,h" %{{.*}} : (i16) -> i16
+ %za = llvm.zext %a : i8 to i16
+ %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$w0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$r0}, r;\n\t"
+ ro(%za : i16)
+ -> i16
+ llvm.return
+}
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index d69de99..7d8ccd9 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt -convert-openmp-to-llvm -split-input-file %s | FileCheck %s
// RUN: mlir-opt -convert-to-llvm -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-llvm="allow-pattern-rollback=0" -split-input-file %s | FileCheck %s
// CHECK-LABEL: llvm.func @foo(i64, i64)
func.func private @foo(index, index)
diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
new file mode 100644
index 0000000..dc645fe
--- /dev/null
+++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
@@ -0,0 +1,318 @@
+// RUN: mlir-opt %s -convert-to-llvm | FileCheck %s
+
+// Tests different variants of ptr_add operation with various attributes
+// (regular, nusw, nuw, inbounds)
+// CHECK-LABEL: llvm.func @test_ptr_add(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: i64) -> !llvm.struct<(ptr, ptr, ptr, ptr)> {
+// CHECK: %[[VAL_0:.*]] = llvm.getelementptr %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[VAL_1:.*]] = llvm.getelementptr nusw %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[VAL_2:.*]] = llvm.getelementptr nuw %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr inbounds %[[ARG0]]{{\[}}%[[ARG1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, ptr, ptr)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_4]][0] : !llvm.struct<(ptr, ptr, ptr, ptr)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_5]][1] : !llvm.struct<(ptr, ptr, ptr, ptr)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_6]][2] : !llvm.struct<(ptr, ptr, ptr, ptr)>
+// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_7]][3] : !llvm.struct<(ptr, ptr, ptr, ptr)>
+// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
+// CHECK: }
+func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
+ %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
+ %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
+ %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
+ %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
+ return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
+}
+
+// Tests type_offset operation which returns the size of different types
+// CHECK-LABEL: llvm.func @test_type_offset() -> !llvm.struct<(i64, i64, i64)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][1] : (!llvm.ptr) -> !llvm.ptr, f32
+// CHECK: %[[VAL_2:.*]] = llvm.ptrtoint %[[VAL_1]] : !llvm.ptr to i64
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_3]][1] : (!llvm.ptr) -> !llvm.ptr, i64
+// CHECK: %[[VAL_5:.*]] = llvm.ptrtoint %[[VAL_4]] : !llvm.ptr to i64
+// CHECK: %[[VAL_6:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
+// CHECK: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64
+// CHECK: %[[VAL_9:.*]] = llvm.mlir.poison : !llvm.struct<(i64, i64, i64)>
+// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_9]][0] : !llvm.struct<(i64, i64, i64)>
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_10]][1] : !llvm.struct<(i64, i64, i64)>
+// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_11]][2] : !llvm.struct<(i64, i64, i64)>
+// CHECK: llvm.return %[[VAL_12]] : !llvm.struct<(i64, i64, i64)>
+// CHECK: }
+func.func @test_type_offset() -> (index, index, index) {
+ %0 = ptr.type_offset f32 : index
+ %1 = ptr.type_offset i64 : index
+ %2 = ptr.type_offset !llvm.struct<(i32, f64)> : index
+ return %0, %1, %2 : index, index, index
+}
+
+// Tests converting a memref to a pointer using to_ptr
+// CHECK-LABEL: llvm.func @test_to_ptr(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64) -> !llvm.ptr {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: llvm.return %[[VAL_6]] : !llvm.ptr
+// CHECK: }
+func.func @test_to_ptr(%arg0: memref<10xf32, #ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
+ %0 = ptr.to_ptr %arg0 : memref<10xf32, #ptr.generic_space> -> <#ptr.generic_space>
+ return %0 : !ptr.ptr<#ptr.generic_space>
+}
+
+// Tests extracting metadata from a static-sized memref
+// CHECK-LABEL: llvm.func @test_get_metadata_static(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.mlir.undef : !llvm.struct<(ptr)>
+// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.struct<(ptr)>
+// CHECK: llvm.return %[[VAL_10]] : !llvm.struct<(ptr)>
+// CHECK: }
+func.func @test_get_metadata_static(%arg0: memref<10x20xf32, #ptr.generic_space>) -> !ptr.ptr_metadata<memref<10x20xf32, #ptr.generic_space>> {
+ %0 = ptr.get_metadata %arg0 : memref<10x20xf32, #ptr.generic_space>
+ return %0 : !ptr.ptr_metadata<memref<10x20xf32, #ptr.generic_space>>
+}
+
+// Tests extracting metadata from a dynamically-sized memref
+// CHECK-LABEL: llvm.func @test_get_metadata_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr, i64, i64, i64)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_11:.*]] = llvm.extractvalue %[[VAL_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_10]][1] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_13:.*]] = llvm.extractvalue %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_14:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_12]][2] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_15:.*]] = llvm.extractvalue %[[VAL_7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_16:.*]] = llvm.insertvalue %[[VAL_15]], %[[VAL_14]][3] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: llvm.return %[[VAL_16]] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: }
+func.func @test_get_metadata_dynamic(%arg0: memref<?x?xf32, #ptr.generic_space>) -> !ptr.ptr_metadata<memref<?x?xf32, #ptr.generic_space>> {
+ %0 = ptr.get_metadata %arg0 : memref<?x?xf32, #ptr.generic_space>
+ return %0 : !ptr.ptr_metadata<memref<?x?xf32, #ptr.generic_space>>
+}
+
+// Tests reconstructing a static-sized memref from a pointer and metadata
+// CHECK-LABEL: llvm.func @test_from_ptr_static(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.struct<(ptr)>) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(ptr)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_3]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(10 : index) : i64
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.return %[[VAL_13]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+func.func @test_from_ptr_static(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: !ptr.ptr_metadata<memref<10x20xf32, #ptr.generic_space>>) -> memref<10x20xf32, #ptr.generic_space> {
+ %0 = ptr.from_ptr %arg0 metadata %arg1 : <#ptr.generic_space> -> memref<10x20xf32, #ptr.generic_space>
+ return %0 : memref<10x20xf32, #ptr.generic_space>
+}
+
+// Tests reconstructing a dynamically-sized memref from a pointer and metadata
+// CHECK-LABEL: llvm.func @test_from_ptr_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.struct<(ptr, i64, i64, i64)>) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_3]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[ARG1]][2] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[ARG1]][3] : !llvm.struct<(ptr, i64, i64, i64)>
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.return %[[VAL_13]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+func.func @test_from_ptr_dynamic(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: !ptr.ptr_metadata<memref<?x?xf32, #ptr.generic_space>>) -> memref<?x?xf32, #ptr.generic_space> {
+ %0 = ptr.from_ptr %arg0 metadata %arg1 : <#ptr.generic_space> -> memref<?x?xf32, #ptr.generic_space>
+ return %0 : memref<?x?xf32, #ptr.generic_space>
+}
+
+// Tests a round-trip conversion of a memref with mixed static/dynamic dimensions
+// CHECK-LABEL: llvm.func @test_memref_mixed(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG7]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[ARG8]], %[[VAL_8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_9]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_11:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i64)>
+// CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_9]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0] : !llvm.struct<(ptr, i64, i64)>
+// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_9]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][1] : !llvm.struct<(ptr, i64, i64)>
+// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][2] : !llvm.struct<(ptr, i64, i64)>
+// CHECK: llvm.return %[[VAL_9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: }
+func.func @test_memref_mixed(%arg0: memref<10x?x30xf32, #ptr.generic_space>) -> memref<10x?x30xf32, #ptr.generic_space> {
+ %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
+ %1 = ptr.get_metadata %arg0 : memref<10x?x30xf32, #ptr.generic_space>
+ %2 = ptr.from_ptr %0 metadata %1 : <#ptr.generic_space> -> memref<10x?x30xf32, #ptr.generic_space>
+ return %2 : memref<10x?x30xf32, #ptr.generic_space>
+}
+
+// Tests a round-trip conversion of a strided memref with explicit offset
+// CHECK-LABEL: llvm.func @test_memref_strided(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_9:.*]] = llvm.mlir.undef : !llvm.struct<(ptr)>
+// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0] : !llvm.struct<(ptr)>
+// CHECK: llvm.return %[[VAL_7]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+func.func @test_memref_strided(%arg0: memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space>) -> memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> {
+ %0 = ptr.to_ptr %arg0 : memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space> -> <#ptr.generic_space>
+ %1 = ptr.get_metadata %arg0 : memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space>
+ %2 = ptr.from_ptr %0 metadata %1 : <#ptr.generic_space> -> memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space>
+ return %2 : memref<10x20xf32, strided<[40, 2], offset: 5>, #ptr.generic_space>
+}
+
+// Tests a comprehensive scenario with fully dynamic memref, including pointer arithmetic
+// CHECK-LABEL: llvm.func @test_comprehensive_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_9:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_7]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_7]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][1] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][2] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][3] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_18:.*]] = llvm.extractvalue %[[VAL_7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_19:.*]] = llvm.insertvalue %[[VAL_18]], %[[VAL_17]][4] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_20:.*]] = llvm.extractvalue %[[VAL_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_19]][5] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_22:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[VAL_23:.*]] = llvm.getelementptr %[[VAL_22]][1] : (!llvm.ptr) -> !llvm.ptr, f32
+// CHECK: %[[VAL_24:.*]] = llvm.ptrtoint %[[VAL_23]] : !llvm.ptr to i64
+// CHECK: %[[VAL_25:.*]] = llvm.getelementptr inbounds %[[VAL_8]]{{\[}}%[[VAL_24]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[VAL_26:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_27:.*]] = llvm.extractvalue %[[VAL_21]][0] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_28:.*]] = llvm.insertvalue %[[VAL_27]], %[[VAL_26]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_29:.*]] = llvm.insertvalue %[[VAL_25]], %[[VAL_28]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_30:.*]] = llvm.extractvalue %[[VAL_21]][1] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_31:.*]] = llvm.insertvalue %[[VAL_30]], %[[VAL_29]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_32:.*]] = llvm.extractvalue %[[VAL_21]][2] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_31]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_34:.*]] = llvm.extractvalue %[[VAL_21]][3] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_33]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_36:.*]] = llvm.extractvalue %[[VAL_21]][4] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_37:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_35]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[VAL_38:.*]] = llvm.extractvalue %[[VAL_21]][5] : !llvm.struct<(ptr, i64, i64, i64, i64, i64)>
+// CHECK: %[[VAL_39:.*]] = llvm.insertvalue %[[VAL_38]], %[[VAL_37]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.return %[[VAL_39]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>) -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> {
+ %0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
+ %1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
+ %2 = ptr.type_offset f32 : index
+ %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
+ %4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
+ return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
+}
+
+// Tests a round-trip conversion of a 0D (scalar) memref
+// CHECK-LABEL: llvm.func @test_memref_0d(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64) -> !llvm.struct<(ptr, ptr, i64)> {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_3]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.undef : !llvm.struct<(ptr)>
+// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][0] : !llvm.struct<(ptr)>
+// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: }
+func.func @test_memref_0d(%arg0: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ %0 = ptr.to_ptr %arg0 : memref<f32, #ptr.generic_space> -> <#ptr.generic_space>
+ %1 = ptr.get_metadata %arg0 : memref<f32, #ptr.generic_space>
+ %2 = ptr.from_ptr %0 metadata %1 : <#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %2 : memref<f32, #ptr.generic_space>
+}
+
+// Tests ptr indexing with a pointer coming from a memref.
+// CHECK-LABEL: llvm.func @test_memref_ptradd_indexing(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: i64, %[[ARG3:.*]]: i64, %[[ARG4:.*]]: i64, %[[ARG5:.*]]: i64, %[[ARG6:.*]]: i64, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: i64, %[[ARG9:.*]]: i64) -> !llvm.ptr {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_1:.*]] = llvm.insertvalue %[[ARG0]], %[[VAL_0]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[ARG1]], %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[ARG2]], %[[VAL_2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[ARG3]], %[[VAL_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[ARG6]], %[[VAL_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[ARG4]], %[[VAL_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[ARG7]], %[[VAL_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[ARG5]], %[[VAL_7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[ARG8]], %[[VAL_8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_9]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_11:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_11]][1] : (!llvm.ptr) -> !llvm.ptr, f32
+// CHECK: %[[VAL_13:.*]] = llvm.ptrtoint %[[VAL_12]] : !llvm.ptr to i64
+// CHECK: %[[VAL_14:.*]] = llvm.mul %[[VAL_13]], %[[ARG9]] : i64
+// CHECK: %[[VAL_15:.*]] = llvm.getelementptr %[[VAL_10]]{{\[}}%[[VAL_14]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: llvm.return %[[VAL_15]] : !llvm.ptr
+// CHECK: }
+func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_space>, %arg1: index) -> !ptr.ptr<#ptr.generic_space> {
+ %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
+ %1 = ptr.type_offset f32 : index
+ %2 = arith.muli %1, %arg1 : index
+ %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+ return %3 : !ptr.ptr<#ptr.generic_space>
+}
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index ef0fa08..483c7b3 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -18,6 +18,24 @@ func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
return
}
+// CHECK-LABEL: func @unsigned_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
+// CHECK-NEXT: ^bb1(%{{.*}}: index): // 2 preds: ^bb0, ^bb2
+// CHECK-NEXT: %{{.*}} = arith.cmpi ult, %{{.*}}, %{{.*}} : index
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2, ^bb3
+// CHECK-NEXT: ^bb2: // pred: ^bb1
+// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
+// CHECK-NEXT: %[[iv:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT: cf.br ^bb1(%[[iv]] : index)
+// CHECK-NEXT: ^bb3: // pred: ^bb1
+// CHECK-NEXT: return
+func.func @unsigned_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
+ scf.for unsigned %i0 = %arg0 to %arg1 step %arg2 {
+ %c1 = arith.constant 1 : index
+ }
+ return
+}
+
// CHECK-LABEL: func @simple_std_2_for_loops(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
// CHECK-NEXT: ^bb1(%[[ub0:.*]]: index): // 2 preds: ^bb0, ^bb5
diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 81661ec..9c55216 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -5,6 +5,7 @@ module attributes {
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {
+// CHECK-LABEL: @loop_kernel
func.func @loop_kernel(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
// CHECK: %[[LB:.*]] = spirv.Constant 4 : i32
%lb = arith.constant 4 : index
@@ -34,6 +35,19 @@ func.func @loop_kernel(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @unsigned_loop
+// CHECK: spirv.ULessThan
+func.func @unsigned_loop(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+ %lb = arith.constant 4 : index
+ %ub = arith.constant 42 : index
+ %step = arith.constant 2 : index
+ scf.for unsigned %arg4 = %lb to %ub step %step {
+ %1 = memref.load %arg2[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+ memref.store %1, %arg3[%arg4] : memref<10xf32, #spirv.storage_class<StorageBuffer>>
+ }
+ return
+}
+
// CHECK-LABEL: @loop_yield
func.func @loop_yield(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
// CHECK: %[[LB:.*]] = spirv.Constant 4 : i32
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
index 4b32495..749c833 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
@@ -3,6 +3,6 @@
// CHECK-LABEL: @apply_scale_unsupported_inexact_round
func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// expected-error@+1 {{failed to legalize operation 'tosa.apply_scale'}}
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = INEXACT_ROUND} : (i64, i32, i8) -> i32
return %res : i32
}
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index db68ca4..f293138 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
// CHECK: return %[[RESULT]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
return %res : i32
}
@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// SCALE: tosa.apply_scale
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
// CHECK-NOT: "tosa.apply_scale"
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = DOUBLE_ROUND} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
return %res : vector<4xi32>
}
@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = DOUBLE_ROUND} : (i48, i32, i8) -> i32
return %res : i32
}
@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = DOUBLE_ROUND} : (i64, i32, i8) -> i32
return %res : i32
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 69d8471..ecfd953 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -38,7 +38,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
%input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
-
-// -----
-
-func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
- // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
- %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
- return %0 : tensor<2x3xi32>
-}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a737a8a..9ea224a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -423,7 +423,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
// CHECK: %[[C30:.+]] = arith.constant 30
// CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
- // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = SINGLE_ROUND}
// Perform the normalization.
// CHECK: %[[CMIN:.+]] = arith.constant -128
@@ -1018,7 +1018,7 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
// CHECK: linalg.pooling_nhwc_max
// CHECK-NOT: linalg.generic
- %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = PROPAGATE} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
return %0 : tensor<1x4x32x62xf32>
}
@@ -1028,7 +1028,7 @@ func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4
func.func @max_pool2d_nan_ignore_int(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>) {
// CHECK: linalg.pooling_nhwc_max
// CHECK-NOT: linalg.generic
- %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xi8>) -> tensor<1x4x32x62xi8>
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = IGNORE} : (tensor<1x6x34x62xi8>) -> tensor<1x4x32x62xi8>
return %0: tensor<1x4x32x62xi8>
}
@@ -1042,6 +1042,6 @@ func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32
// CHECK: arith.cmpf uno
// CHECK: arith.select
// CHECK: linalg.yield
- %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = IGNORE} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
return %0: tensor<1x4x32x62xf32>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index 59ccdaa..ff2cbbc 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -5,7 +5,7 @@ func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf32>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf32>
}
@@ -17,7 +17,7 @@ func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<3x1x1x7xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf16>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf16>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf16>
}
@@ -29,7 +29,7 @@ func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf32>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf32>
}
@@ -41,7 +41,7 @@ func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x1x7xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf16>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xf16>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf16>
}
@@ -53,7 +53,7 @@ func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7
%scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xi8>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xi8>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xi8>
}
@@ -73,7 +73,7 @@ func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3
%scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xf32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xf32>
return %resize : tensor<3x1x5x7xf32>
}
@@ -110,7 +110,7 @@ func.func @broadcast_resize_bilinear_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x
%scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4x5x7xi32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4x5x7xi32>
return %resize : tensor<3x4x5x7xi32>
}
@@ -139,7 +139,7 @@ func.func @unary_resize_bilinear_i32(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1
%scale = tosa.const_shape { values = dense<[2, 1, 2, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xi32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xi32>
// CHECK: return %[[EXPAND]]
return %resize : tensor<3x1x1x7xi32>
@@ -210,7 +210,7 @@ func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
%scale = tosa.const_shape { values = dense<[11, 7, 89, 6]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x23x179x1xi8>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x23x179x1xi8>
return
}
@@ -314,7 +314,7 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
%scale = tosa.const_shape { values = dense<[16, 1, 16, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x19x20x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x289x305x1xi48>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x19x20x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x289x305x1xi48>
return
}
@@ -381,7 +381,7 @@ func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
%scale = tosa.const_shape { values = dense<[64, 2, 64, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<[-31, -31]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[31, 31]> : tensor<2xindex> } : () -> !tosa.shape<2>
- %output = tosa.resize %input, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1600x1536x1xf32>
+ %output = tosa.resize %input, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1600x1536x1xf32>
return
}
@@ -476,7 +476,7 @@ func.func @resize_bilinear_fp(%input: tensor<1x23x24x1xf32>) -> () {
%scale = tosa.const_shape { values = dense<[4, 1, 4, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %output = tosa.resize %input, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x23x24x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x89x93x1xf32>
+ %output = tosa.resize %input, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x23x24x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x89x93x1xf32>
return
}
@@ -493,7 +493,7 @@ func.func @resize_dyn(%input: tensor<?x2x2x1xi8>) -> () {
%scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
- %output = tosa.resize %input, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<?x2x2x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> (tensor<?x4x4x1xi32>)
+ %output = tosa.resize %input, %scale, %offset, %border { mode = BILINEAR } : (tensor<?x2x2x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> (tensor<?x4x4x1xi32>)
return
}
@@ -504,7 +504,7 @@ func.func @resize_bilinear_int48(%arg0: tensor<1x19x19x1xi16>) {
%scale = tosa.const_shape { values = dense<[16, 1, 16, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x19x19x1xi16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x289x289x1xi48>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x19x19x1xi16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x289x289x1xi48>
return
}
@@ -530,7 +530,7 @@ func.func @skip_interpolate_bilinear_i8(%arg0 : tensor<3x1x2x7xi8>) -> tensor<3x
%scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x2x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x4x7xi32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x2x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x4x7xi32>
// CHECK: return %[[GENERIC]]
return %resize : tensor<3x1x4x7xi32>
@@ -552,7 +552,7 @@ func.func @skip_interpolate_bilinear_f32(%arg0 : tensor<3x1x2x7xf32>) -> tensor<
%scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x2x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x4x7xf32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x2x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x4x7xf32>
// CHECK: return %[[GENERIC]]
return %resize : tensor<3x1x4x7xf32>
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fb912e4..3fc513f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1131,7 +1131,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1143,7 +1143,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1164,7 +1164,7 @@ func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1177,7 +1177,7 @@ func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
// CHECK: return
return
@@ -1198,7 +1198,7 @@ func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1211,7 +1211,7 @@ func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1232,7 +1232,7 @@ func.func @rescale_i48_unsigned_output_implicit(%arg0 : tensor<2xi48>) -> () {
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1244,7 +1244,7 @@ func.func @rescale_i48_unsigned_output_implicit(%arg0 : tensor<2xi48>) -> () {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1265,13 +1265,13 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
return
}
@@ -1293,7 +1293,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
%multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = DOUBLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
return
}
@@ -1313,7 +1313,7 @@ func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () {
// CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1325,7 +1325,7 @@ func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return
}
@@ -1346,7 +1346,7 @@ func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1358,7 +1358,7 @@ func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return
}
@@ -1379,7 +1379,7 @@ func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> (
// CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1392,7 +1392,7 @@ func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> (
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
return
}
@@ -1414,7 +1414,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C243]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {rounding_mode = SINGLE_ROUND}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C252]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1426,7 +1426,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
%shift = "tosa.const"() {values = dense<[14, 15, 64]> : tensor<3xi8>} : () -> tensor<3xi8>
%input_zp = "tosa.const"() {values = dense<43> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<52> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<3xi8>
// CHECK: return [[GENERIC]]
return %0 : tensor<3xi8>
@@ -1443,8 +1443,8 @@ func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
// CHECK: linalg.generic
// CHECK: tosa.apply_scale
- // CHECK-SAME: {rounding_mode = "DOUBLE_ROUND"}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ // CHECK-SAME: {rounding_mode = DOUBLE_ROUND}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -1459,8 +1459,8 @@ func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>)
// CHECK: linalg.generic
// CHECK: tosa.apply_scale
- // CHECK-SAME: {rounding_mode = "SINGLE_ROUND"}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ // CHECK-SAME: {rounding_mode = SINGLE_ROUND}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -1472,7 +1472,7 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "INEXACT_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = INEXACT_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -2183,7 +2183,7 @@ func.func @reduce_min_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf3
// CHECK-NOT: tensor.empty()
// CHECK-NOT: select
// CHECK: return
- %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = PROPAGATE} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
@@ -2202,7 +2202,7 @@ func.func @reduce_max_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf3
// CHECK-NOT: tensor.empty()
// CHECK-NOT: select
// CHECK: return
- %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = PROPAGATE} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
@@ -2221,7 +2221,7 @@ func.func @reduce_min_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8
// CHECK-NOT: tensor.empty()
// CHECK-NOT: select
// CHECK: return
- %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<1x4xi8>
+ %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = IGNORE} : (tensor<5x4xi8>) -> tensor<1x4xi8>
return
}
@@ -2240,7 +2240,7 @@ func.func @reduce_max_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8
// CHECK-NOT: tensor.empty()
// CHECK-NOT: select
// CHECK: return
- %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<1x4xi8>
+ %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = IGNORE} : (tensor<5x4xi8>) -> tensor<1x4xi8>
return
}
@@ -2258,7 +2258,7 @@ func.func @reduce_min_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
// CHECK: linalg.fill
// CHECK: tensor.empty()
// CHECK: select
- %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = IGNORE} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
@@ -2276,7 +2276,7 @@ func.func @reduce_max_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
// CHECK: linalg.fill
// CHECK: tensor.empty()
// CHECK: select
- %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = IGNORE} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
@@ -2289,7 +2289,7 @@ func.func @minimum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ %7 = tosa.minimum %arg0, %arg1 {nan_mode = PROPAGATE} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
return
}
@@ -2302,7 +2302,7 @@ func.func @maximum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ %8 = tosa.maximum %arg0, %arg1 {nan_mode = PROPAGATE} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
return
}
@@ -2315,7 +2315,7 @@ func.func @minimum_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>)
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
+ %9 = tosa.minimum %arg0, %arg1 {nan_mode = IGNORE} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
return
}
@@ -2328,7 +2328,7 @@ func.func @maximum_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>)
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
+ %10 = tosa.maximum %arg0, %arg1 {nan_mode = IGNORE} : (tensor<5x4xi8>, tensor<5x4xi8>) -> tensor<5x4xi8>
return
}
@@ -2343,7 +2343,7 @@ func.func @minimum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
// CHECK: arith.select
// CHECK: arith.select
// CHECK: linalg.yield
- %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ %9 = tosa.minimum %arg0, %arg1 {nan_mode = IGNORE} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
return
}
@@ -2358,7 +2358,7 @@ func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
// CHECK: arith.select
// CHECK: arith.select
// CHECK: linalg.yield
- %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ %10 = tosa.maximum %arg0, %arg1 {nan_mode = IGNORE} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
return
}
@@ -2375,7 +2375,7 @@ func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>)
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+ %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = PROPAGATE} : (tensor<5x4xf32>) -> tensor<4xi32>
return
}
@@ -2392,7 +2392,7 @@ func.func @argmax_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) -
// CHECK-NOT: arith.select
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<4xi32>
+ %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = IGNORE} : (tensor<5x4xi8>) -> tensor<4xi32>
return
}
@@ -2405,7 +2405,7 @@ func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
// CHECK: arith.select
// CHECK: arith.select
// CHECK: linalg.yield
- %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+ %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = IGNORE} : (tensor<5x4xf32>) -> tensor<4xi32>
return
}
@@ -2419,7 +2419,7 @@ func.func @clamp_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %13 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+ %13 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = PROPAGATE} : (tensor<5x4xf32>) -> tensor<5x4xf32>
return
}
@@ -2433,7 +2433,7 @@ func.func @clamp_nan_ignore_int(%arg0: tensor<5x4xi8>, %arg1: tensor<5x4xi8>) ->
// CHECK-NOT: arith.cmpf uno
// CHECK-NOT: arith.select
// CHECK: linalg.yield
- %14 = tosa.clamp %arg0 {min_val = 1 : i8, max_val = 5 : i8, nan_mode = "IGNORE"} : (tensor<5x4xi8>) -> tensor<5x4xi8>
+ %14 = tosa.clamp %arg0 {min_val = 1 : i8, max_val = 5 : i8, nan_mode = IGNORE} : (tensor<5x4xi8>) -> tensor<5x4xi8>
return
}
@@ -2447,7 +2447,7 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
// CHECK: arith.cmpf uno
// CHECK: arith.select
// CHECK: linalg.yield
- %14 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+ %14 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = IGNORE} : (tensor<5x4xf32>) -> tensor<5x4xf32>
return
}
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
return
}
+
+// -----
+
+// CHECK-LABEL: @mul_no_const_shift
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
+ // CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
+ %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
index 5307e47..6c0b111 100644
--- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
+++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
@@ -3,6 +3,7 @@
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm="filter-dialects=ub" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=ub allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// CHECK-LABEL: @check_poison
func.func @check_poison() {
diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
new file mode 100644
index 0000000..4fb88dd
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -0,0 +1,310 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// VNNI format is Intel's packed data layout.
+/// For matrix multiplication, elements from the reduction dimension `k`
+/// are packed into 32-bit tuples. Then the appropriate AMX operations can
+/// perform tile multiplication directly on the packed data.
+///
+/// These packed elements are represented in the indexing maps by a separate
+/// reduction dimension `vnni`.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_f16(
+// CHECK-SAME: %[[A:.+]]: vector<4x8x2xf16>,
+// CHECK-SAME: %[[B:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME: %[[C:.+]]: vector<4x16xf32>
+
+/// AMX hardware has no direct access to the registers. Thus, data must
+/// be transfered through intermediate buffers.
+///
+/// Load A vector into an AMX tile
+// CHECK: %[[A_BUF:.+]] = memref.alloca() : memref<4x8x2xf16>
+// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
+// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
+
+/// Load B vector into an AMX tile
+// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
+// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
+// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
+
+/// Load C vector into an AMX tile
+// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Load the result back into a vector
+// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
+// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+
+// CHECK: return %[[RES_VEC]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_bf16(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+ %C: vector<4x8xi32>) -> vector<4x8xi32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+ return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_vnni_i8(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(vnni, m, k, n) -> (m, k, vnni)>
+#map1 = affine_map<(vnni, m, k, n) -> (k, n, vnni)>
+#map2 = affine_map<(vnni, m, k, n) -> (m, n)>
+func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+ %C: vector<4x8xi32>) -> vector<4x8xi32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "reduction", "parallel"]}
+ %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+ return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_shuffled_iterators(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_kind(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<mul>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_kind(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, k, vnni) -> (k, m, vnni)>
+#map2 = affine_map<(m, k, vnni) -> ()>
+func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
+ %C: f32) -> f32 {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x4x2xf16> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @negative_non_vector_acc(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>,
+ %B: vector<8x16x2xf32>, %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf32>, vector<8x16x2xf32> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operand_types(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+func.func @negative_non_packed_layout(%A: vector<4x16xf16>, %B: vector<16x16xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ %A, %B, %C : vector<4x16xf16>, vector<16x16xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_packed_layout(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4xf16>,
+ %C: vector<4x2xf32>) -> vector<4x2xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x2x4xf16>, vector<2x2x4xf16> into vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_vnni_factor(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(batch, m, n, k, vnni) -> (batch, m, k, vnni)>
+#map1 = affine_map<(batch, m, n, k, vnni) -> (batch, k, n, vnni)>
+#map2 = affine_map<(batch, m, n, k, vnni) -> (batch, m, n)>
+func.func @negative_invalid_operands_shapes(%A: vector<1x4x8x2xf16>,
+ %B: vector<1x8x16x2xf16>, %C: vector<1x4x16xf32>) -> vector<1x4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<1x4x8x2xf16>, vector<1x8x16x2xf16> into vector<1x4x16xf32>
+ return %0 : vector<1x4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operands_shapes(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_many_rows(%A: vector<32x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<32x16xf32>) -> vector<32x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<32x8x2xf16>, vector<8x16x2xf16> into vector<32x16xf32>
+ return %0 : vector<32x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_many_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_wide_rows(%A: vector<4x32x2xf16>, %B: vector<32x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x32x2xf16>, vector<32x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_wide_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (k, vnni, m)>
+#map1 = affine_map<(m, n, k, vnni) -> (n, k, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_input_dim_permutation(%A: vector<2x2x2xf16>,
+ %B: vector<2x2x2xf16>, %C: vector<2x2xf32>) -> vector<2x2xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<2x2x2xf16>, vector<2x2x2xf16> into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @negative_input_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (n, m)>
+func.func @negative_output_dim_permutation(%A: vector<4x8x2xf16>,
+ %B: vector<8x16x2xf16>, %C: vector<16x4xf32>) -> vector<16x4xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<16x4xf32>
+ return %0 : vector<16x4xf32>
+}
+
+// CHECK-LABEL: @negative_output_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract
diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
new file mode 100644
index 0000000..8fab4cf
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
@@ -0,0 +1,355 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// These test cases validate replacement of vector transfer ops with equivalent
+/// AMX tile data transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
+ %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
+ %c0_f16 = arith.constant 0.0 : f16
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<64x32x16x2xf16>, vector<4x8x2xf16>
+ %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<64x16x32x2xf16>, vector<8x16x2xf16>
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @transfers_static_dims(
+// CHECK-SAME: %[[A:.+]]: memref<64x32x16x2xf16>,
+// CHECK-SAME: %[[B:.+]]: memref<64x16x32x2xf16>,
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>,
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+/// Load A into an AMX tile
+// CHECK: %[[A_SUBVIEW:.+]] = memref.subview %[[A]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[A]]
+
+/// Load B into an AMX tile
+// CHECK: %[[B_SUBVIEW:.+]] = memref.subview %[[B]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[B]]
+
+/// Load C into an AMX tile
+// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[C]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result back
+// CHECK: %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK: amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK-NOT: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfers_dynamic_outer_dims(%A: memref<?x?x16x2xf16>,
+ %B: memref<?x?x32x2xf16>, %C: memref<?x64xf32>, %idx: index) {
+ %c0_f16 = arith.constant 0.0 : f16
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<?x?x16x2xf16>, vector<4x8x2xf16>
+ %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<?x?x32x2xf16>, vector<8x16x2xf16>
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<?x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<?x64xf32>
+ return
+}
+
+// CHECK-LABEL: @transfers_dynamic_outer_dims(
+// CHECK-SAME: %[[A:.+]]: memref<?x?x16x2xf16>,
+// CHECK-SAME: %[[B:.+]]: memref<?x?x32x2xf16>,
+// CHECK-SAME: %[[C:.+]]: memref<?x64xf32>
+// CHECK-NOT: vector.transfer_read %[[A]]
+// CHECK-NOT: vector.transfer_read %[[B]]
+// CHECK-NOT: vector.transfer_read %[[C]]
+// CHECK-NOT: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// AMX tile can be loaded directly from the buffer. However, vector transfer
+/// has to remain due to other users that require data in registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+ return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @transfer_read_multiple_users(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>,
+
+/// Load to AMX tile directly from buffer.
+// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+
+/// Vector read remains to load data for the other non-AMX consumer.
+// CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
+
+/// Contraction uses the directly loaded tile.
+// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
+
+/// Consumer uses original C value and the updated one after contraction.
+// CHECK: %[[RES_BUF:.+]] = memref.alloca
+// CHECK: amx.tile_store %[[RES_BUF]]
+// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]], %[[RES_VEC]]
+
+// -----
+
+/// As contraction has multiple users, the results have to loaded back
+/// from AMX tile into registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_contract_multiple_users(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %vecC: vector<4x16xf32>, %idx: index) -> vector<4x16xf32> {
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+ return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_contract_multiple_users(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
+// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_out_of_bounds(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %vecC: vector<4x16xf32>, %idx: index) {
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, false]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_out_of_bounds(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_non_identity_map(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %vecC: vector<4x16xf32>, %idx: index) {
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+ in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_non_identity_map(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// AMX tile transfers require row elements to be contiguous
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_non_contiguous_row(
+ %A: memref<8x128x2xf16, strided<[256, 4, 1]>>,
+ %vecB: vector<8x16x2xf16>, %vecC: vector<4x16xf32>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f16 = arith.constant 0.0 : f16
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]}
+ : memref<8x128x2xf16, strided<[256, 4, 1]>>, vector<4x8x2xf16>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_contiguous_row(
+// CHECK-SAME: %[[A:.+]]: memref<8x128x2xf16, strided<[256, 4, 1]>>
+// CHECK: vector.transfer_read %[[A]]
+
+// -----
+
+/// Buffer shape checks are conservative to avoid problems with deriving
+/// stride for AMX tile rows.
+/// When in doubt, vector operations are left to perform initial transfers.
+/// Afterwards, data can be placed in a contiguous temporary buffer which
+/// ensures correct layout for AMX transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_1D_buffer(%C: memref<512xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecC = vector.transfer_read %C[%idx], %c0_f32
+ {permutation_map = affine_map<(d0) -> (0, d0)>,
+ in_bounds = [true, true]} : memref<512xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_1D_buffer(
+// CHECK-SAME: %[[C:.+]]: memref<512xf32>
+// CHECK: vector.transfer_read %[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_dynamic_shapes(%A: memref<?x?x?x2xf16>,
+ %B: memref<?x?x2xf16>, %C: memref<?x?xf32>, %idx: index) {
+ %c0_f16 = arith.constant 0.0 : f16
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<?x?x?x2xf16>, vector<4x8x2xf16>
+ %vecB = vector.transfer_read %B[%idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<?x?x2xf16>, vector<8x16x2xf16>
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<?x?xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<?x?xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_dynamic_shapes(
+// CHECK-SAME: %[[A:.+]]: memref<?x?x?x2xf16>,
+// CHECK-SAME: %[[B:.+]]: memref<?x?x2xf16>,
+// CHECK-SAME: %[[C:.+]]: memref<?x?xf32>
+// CHECK: vector.transfer_read %[[A]]
+// CHECK: vector.transfer_read %[[B]]
+// CHECK: vector.transfer_read %[[C]]
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_buffer_row_shape(%C: memref<5x2x4x4xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecC = vector.transfer_read %C[%idx, %idx, %idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<5x2x4x4xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_buffer_row_shape(
+// CHECK-SAME: %[[C:.+]]: memref<5x2x4x4xf32>
+// CHECK: vector.transfer_read %[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_buffer_non_packed_source_shape(%A: memref<8x64x64xf16>,
+ %vecB: vector<8x16x2xf16>, %vecC: vector<4x16xf32>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f16 = arith.constant 0.0 : f16
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<8x64x64xf16>, vector<4x8x2xf16>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_buffer_non_packed_source_shape(
+// CHECK-SAME: %[[A:.+]]: memref<8x64x64xf16>
+// CHECK: vector.transfer_read %[[A]]
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb..9b57b1b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --convert-to-llvm="filter-dialects=vector" --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-to-llvm="filter-dialects=vector allow-pattern-rollback=0" --split-input-file %s | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s
//===========================================================================//
@@ -182,8 +183,7 @@ func.func @shuffle_0D_direct(%arg0: vector<f32>) -> vector<3xf32> {
}
// CHECK-LABEL: @shuffle_0D_direct(
// CHECK-SAME: %[[A:.*]]: vector<f32>
-// CHECK: %[[c:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
-// CHECK: %[[s:.*]] = llvm.shufflevector %[[c]], %[[c]] [0, 1, 0] : vector<1xf32>
+// CHECK: %[[s:.*]] = llvm.shufflevector %{{.*}}, %{{.*}} [0, 1, 0] : vector<1xf32>
// CHECK: return %[[s]] : vector<3xf32>
// -----
@@ -1679,6 +1679,16 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec
// -----
+func.func @load_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+ %0 = vector.load %memref[%i, %j] { alignment = 8 } : memref<200x100xf32>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @load_with_alignment
+// CHECK: llvm.load {{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@@ -1785,6 +1795,16 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
// -----
+func.func @store_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index, %val : vector<4xf32>) {
+ vector.store %val, %memref[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func @store_with_alignment
+// CHECK: llvm.store %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.maskedload
//===----------------------------------------------------------------------===//
@@ -1839,6 +1859,16 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
// -----
+func.func @masked_load_with_alignment(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>, %arg3: index) -> vector<16xf32> {
+ %0 = vector.maskedload %arg0[%arg3], %arg1, %arg2 { alignment = 2 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func @masked_load_with_alignment
+// CHECK: llvm.intr.masked.load %{{.*}} {alignment = 2 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.maskedstore
//===----------------------------------------------------------------------===//
@@ -1891,6 +1921,16 @@ func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16
// -----
+func.func @masked_store_with_alignment(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>, %arg3: index) {
+ vector.maskedstore %arg0[%arg3], %arg1, %arg2 { alignment = 2 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// CHECK-LABEL: func @masked_store_with_alignment
+// CHECK: llvm.intr.masked.store %{{.*}} {alignment = 2 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.gather
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 72810b5..07d3351 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1737,3 +1737,40 @@ func.func @step() -> vector<4xindex> {
%0 = vector.step : vector<4xindex>
return %0 : vector<4xindex>
}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.from_elements
+//===----------------------------------------------------------------------===//
+
+// NOTE: We unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
+// and then convert the 1-D from_elements ops to llvm.
+
+// CHECK-LABEL: func @from_elements_3d
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+// CHECK: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
+// CHECK: %[[UNDEF_RES_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_RES]] : vector<2x1x2xf32> to !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[UNDEF_VEC_RANK_2:.*]] = ub.poison : vector<1x2xf32>
+// CHECK: %[[UNDEF_VEC_RANK_2_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_VEC_RANK_2]] : vector<1x2xf32> to !llvm.array<1 x vector<2xf32>>
+// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+// CHECK: %[[RES_RANK_2_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>>
+// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[RES_RANK_2_0]], %[[UNDEF_RES_LLVM]][0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+// CHECK: %[[RES_RANK_2_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>>
+// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[RES_RANK_2_1]], %[[RES_0]][1] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+// CHECK: return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91..4b56897 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
return %0: vector<1xf32>
}
+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx = arith.constant 0 : index
+ // CHECK: spirv.Load
+ // CHECK-SAME: ["Aligned", 8]
+ %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+ %idx = arith.constant 0 : index
+ // CHECK: spirv.Store
+ // CHECK-SAME: ["Aligned", 8]
+ vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return
+}
+
// CHECK-LABEL: @vector_store_single_elem
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index d1e5a62..b373bda 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -1,212 +1,398 @@
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-ND
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-GATHER
-func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+gpu.module @xevm_module {
+gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
{in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
- return %0 : vector<8xf32>
+ gpu.return %0 : vector<8xf32>
}
-// CHECK-LABEL: @load_1D_vector(
-// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
-// CHECK: return %[[VEC]]
+// LOAD-ND-LABEL: @load_1D_vector(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_1D_vector(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
-// -----
+}
-func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
%offset: index) -> vector<8x16xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
{in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
- return %0 : vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
}
-// CHECK-LABEL: @load_2D_vector(
-// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK: return %[[VEC]]
+// LOAD-ND-LABEL: @load_2D_vector(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_2D_vector(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
-// -----
+}
-func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
%offset: index) -> vector<8x16xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset, %offset], %c0
{in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
- return %0 : vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
}
-// CHECK-LABEL: @load_zero_pad_out_of_bounds(
-// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK: return %[[VEC]]
+// LOAD-ND-LABEL: @load_zero_pad_out_of_bounds(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: return %[[VEC]]
-// -----
+// LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds(
+// LOAD-GATHER: vector.transfer_read
-func.func @load_transposed(%source: memref<32x64xf32>,
- %offset: index) -> vector<8x16xf32> {
+}
+
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_transposed(%source: memref<32x64xf32>,
+ %i: index, %j: index) -> vector<8x16xf32> {
%c0 = arith.constant 0.0 : f32
- %0 = vector.transfer_read %source[%offset, %offset], %c0
+ %0 = vector.transfer_read %source[%i, %j], %c0
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
- return %0 : vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
}
-// CHECK-LABEL: @load_transposed(
-// CHECK-SAME: %[[SRC:.+]]: memref<32x64xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
-// CHECK-SAME: -> vector<8x16xf32>
-// CHECK: return %[[VEC]]
+// LOAD-ND-LABEL: @load_transposed(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD-ND-SAME: %[[OFFSET1:.+]]: index,
+// LOAD-ND-SAME: %[[OFFSET2:.+]]: index
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND-SAME: -> vector<8x16xf32>
+// LOAD-ND: return %[[VEC]]
+
+
+// LOAD-GATHER-LABEL: @load_transposed(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
-// -----
+}
-func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
- %offset: index) -> vector<8x16xf32> {
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+ %i: index, %j: index, %k: index) -> vector<8x16xf32> {
%c0 = arith.constant 0.0 : f32
- %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+ %0 = vector.transfer_read %source[%i, %j, %k], %c0
{in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
- return %0 : vector<8x16xf32>
-}
-
-// CHECK-LABEL: @load_dynamic_source(
-// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK: return %[[VEC]]
+ gpu.return %0 : vector<8x16xf32>
+}
+// LOAD-ND-LABEL: @load_dynamic_source(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[C2:.+]] = arith.constant 2 : index
+// LOAD-ND: %[[C1:.+]] = arith.constant 1 : index
+// LOAD-ND: %[[C0:.+]] = arith.constant 0 : index
+// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: return %[[VEC]]
+
+
+// LOAD-GATHER-LABEL: @load_dynamic_source(
+// LOAD-GATHER-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>,
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER: memref.extract_strided_metadata %[[ARG0]]
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
+ %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%i, %j, %k], %c0
+ {in_bounds = [true, true]} : memref<?x8x16xf32>, vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_dynamic_source2(
+// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
+// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
+
+// LOAD-GATHER-LABEL: @load_dynamic_source2(
+// LOAD-GATHER-DAG: %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
+// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
+ %i: index, %j: index, %k: index, %l: index, %m: index) -> vector<2x4x8x16xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%i, %j, %k, %l, %m], %c0
+ {in_bounds = [true, true, true, true]} : memref<?x?x?x?x?xf32>, vector<2x4x8x16xf32>
+ gpu.return %0 : vector<2x4x8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_dynamic_source3(
+// LOAD-ND: vector.transfer_read
+
+// LOAD-GATHER-LABEL: @load_dynamic_source3(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<?x?x?x?x?xf32>
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
+// LOAD-GATHER: memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
+// LOAD-GATHER-COUNT4: vector.step
+// LOAD-GATHER-COUNT3: vector.broadcast
+// LOAD-GATHER-COUNT4: vector.shape_cast
+// LOAD-GATHER-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
+// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER: return %[[VEC]]
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
+ %offset: index, %arg2: index) -> vector<8x16x32xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
+ {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
+ gpu.return %0 : vector<8x16x32xf32>
+}
+
+// LOAD-ND-LABEL: @load_high_dim_vector(
+// LOAD-ND: vector.transfer_read
+
+// LOAD-GATHER-LABEL: @load_high_dim_vector(
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// LOAD-GATHER: %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// LOAD-GATHER: %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// LOAD-GATHER: %[[C2048:.+]] = arith.constant 2048 : index
+// LOAD-GATHER: %[[C64:.+]] = arith.constant 64 : index
+// LOAD-GATHER-COUNT3: vector.step
+// LOAD-GATHER-COUNT3: vector.shape_cast
+// LOAD-GATHER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+
+}
// -----
+gpu.module @xevm_module {
+gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
+ %offset: index) -> vector<8x16xf16> {
+ %c0 = arith.constant 0.0 : f16
+ %0 = vector.transfer_read %source[%offset, %offset], %c0
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+ in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
+ gpu.return %0 : vector<8x16xf16>
+}
-func.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
+// LOAD-ND-LABEL: @load_transpose_f16(
+// LOAD-ND: vector.transfer_read
+
+// LOAD-GATHER-LABEL: @load_transpose_f16(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<32x64xf16>,
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
%offset: index, %arg2: index, %pad: f32) -> (vector<8x16xf32>, vector<8x16xf32>) {
%c1 = arith.constant 1.0 : f32
%0 = vector.transfer_read %source[%offset, %arg2], %c1
{in_bounds = [true, false]} : memref<32x64xf32>, vector<8x16xf32>
%1 = vector.transfer_read %source[%arg2, %offset], %pad
{in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
- return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
+ gpu.return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
}
-// CHECK-LABEL: @no_load_out_of_bounds_non_zero_pad(
-// CHECK-COUNT-2: vector.transfer_read
+// LOAD-ND-LABEL: @no_load_out_of_bounds_non_zero_pad(
+// LOAD-ND-COUNT-2: vector.transfer_read
-// -----
+// LOAD-GATHER-LABEL: @no_load_out_of_bounds_non_zero_pad(
+// LOAD-GATHER-COUNT-2: vector.transfer_read
+}
-func.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
%offset: index) -> vector<8xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
{in_bounds = [false]} : memref<8x16x32xf32>, vector<8xf32>
- return %0 : vector<8xf32>
+ gpu.return %0 : vector<8xf32>
}
-// CHECK-LABEL: @no_load_out_of_bounds_1D_vector(
-// CHECK: vector.transfer_read
+// LOAD-ND-LABEL: @no_load_out_of_bounds_1D_vector(
+// LOAD-ND: vector.transfer_read
-// -----
+// LOAD-GATHER-LABEL: @no_load_out_of_bounds_1D_vector(
+// LOAD-GATHER: vector.transfer_read
+}
-func.func @no_load_masked(%source : memref<4xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_masked(%source : memref<4xf32>,
%offset : index) -> vector<4xf32> {
%c0 = arith.constant 0.0 : f32
%mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
%0 = vector.transfer_read %source[%offset], %c0, %mask
{in_bounds = [true]} : memref<4xf32>, vector<4xf32>
- return %0 : vector<4xf32>
+ gpu.return %0 : vector<4xf32>
}
-// CHECK-LABEL: @no_load_masked(
-// CHECK: vector.transfer_read
+// LOAD-ND-LABEL: @no_load_masked(
+// LOAD-ND: vector.transfer_read
-// -----
+// LOAD-GATHER-LABEL: @no_load_masked(
+// LOAD-GATHER: vector.transfer_read
+}
-func.func @no_load_tensor(%source: tensor<32x64xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
%offset: index, %arg2: index) -> vector<8x16xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset, %arg2], %c0
{in_bounds = [true, true]} : tensor<32x64xf32>, vector<8x16xf32>
- return %0 : vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
}
-// CHECK-LABEL: @no_load_tensor(
-// CHECK: vector.transfer_read
+// LOAD-ND-LABEL: @no_load_tensor(
+// LOAD-ND: vector.transfer_read
-// -----
-
-func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
- %offset: index, %arg2: index) -> vector<8x16x32xf32> {
- %c0 = arith.constant 0.0 : f32
- %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
- {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
- return %0 : vector<8x16x32xf32>
+// LOAD-GATHER-LABEL: @no_load_tensor(
+// LOAD-GATHER: vector.transfer_read
}
-// CHECK-LABEL: @no_load_high_dim_vector(
-// CHECK: vector.transfer_read
// -----
-
-func.func @no_load_non_unit_inner_stride(
+gpu.module @xevm_module {
+gpu.func @no_load_non_unit_inner_stride(
%source: memref<32xf32, strided<[?], offset: ?>>,
%offset: index) -> vector<8xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset], %c0 {in_bounds = [true]}
: memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
- return %0 : vector<8xf32>
+ gpu.return %0 : vector<8xf32>
}
-// CHECK-LABEL: @no_load_non_unit_inner_stride(
-// CHECK: vector.transfer_read
+// LOAD-ND-LABEL: @no_load_non_unit_inner_stride(
+// LOAD-ND: vector.transfer_read
+
+// LOAD-GATHER-LABEL: @no_load_non_unit_inner_stride(
+// LOAD-GATHER: vector.transfer_read
+}
-// -----
-func.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
%offset: index) -> vector<8x16xf32> {
%c0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
{permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
in_bounds = [true, true]} : memref<16x32x64xf32>, vector<8x16xf32>
- return %0 : vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
}
-// CHECK-LABEL: @no_load_unsupported_map(
-// CHECK: vector.transfer_read
-
-// -----
+// LOAD-ND-LABEL: @no_load_unsupported_map(
+// LOAD-ND: vector.transfer_read
-func.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
- %offset: index) -> vector<8x16xf16> {
- %c0 = arith.constant 0.0 : f16
- %0 = vector.transfer_read %source[%offset, %offset], %c0
- {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
- in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
- return %0 : vector<8x16xf16>
+// LOAD-GATHER-LABEL: @no_load_unsupported_map(
+// LOAD-GATHER: vector.transfer_read
}
-// CHECK-LABEL: @no_load_transpose_unsupported_data_type(
-// CHECK: vector.transfer_read
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index d5f1221..b3f761a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,178 +1,278 @@
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
-func.func @store_1D_vector(%vec: vector<8xf32>,
+
+gpu.module @xevm_module {
+gpu.func @store_1D_vector(%vec: vector<8xf32>,
%source: memref<8x16x32xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset, %offset]
{in_bounds = [true]}
: vector<8xf32>, memref<8x16x32xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @store_1D_vector(
-// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
-// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND-LABEL: @store_1D_vector(
+// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>,
+// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME: boundary_check = false
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// STORE-SCATTER-LABEL: @store_1D_vector(
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf32>,
+// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-SCATTER-DAG: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// STORE-SCATTER-DAG: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT2: arith.muli {{.*}} : index
+// STORE-SCATTER-COUNT2: arith.addi {{.*}} : index
+// STORE-SCATTER-DAG: %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+}
// -----
-
-func.func @store_2D_vector(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
%source: memref<8x16x32xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset, %offset]
{in_bounds = [true, true]}
: vector<8x16xf32>, memref<8x16x32xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @store_2D_vector(
-// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
-// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-LABEL: @store_2D_vector(
+// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME: boundary_check = false
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// STORE-SCATTER-LABEL: @store_2D_vector(
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-SCATTER-SAME: %[[OFFSET:.+]]: index
+// STORE-SCATTER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// STORE-SCATTER-COUNT2: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT2: vector.shape_cast {{.*}}
+// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
+}
// -----
-
-func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
%source: memref<?x?x?xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset, %offset]
{in_bounds = [true, true]}
: vector<8x16xf32>, memref<?x?x?xf32>
- return
-}
-
-// CHECK-LABEL: @store_dynamic_source(
-// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
-// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+ gpu.return
+}
-// -----
+// STORE-ND-LABEL: @store_dynamic_source(
+// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
+// STORE-ND-DAG: %[[C1:.+]] = arith.constant 1 : index
+// STORE-ND-DAG: %[[C2:.+]] = arith.constant 2 : index
+// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// STORE-SCATTER-LABEL: @store_dynamic_source(
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// STORE-SCATTER-DAG: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// STORE-SCATTER-DAG: memref.extract_strided_metadata %[[SRC]] : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
+// STORE-SCATTER-COUNT2: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT2: vector.shape_cast {{.*}}
+// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+}
-func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
%source: memref<7x64xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset]
{in_bounds = [false, true]}
: vector<8x16xf32>, memref<7x64xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @store_out_of_bounds(
-// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
-// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-LABEL: @store_out_of_bounds(
+// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-ND-SAME: %[[SRC:.+]]: memref<7x64xf32>,
+// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// STORE-SCATTER-LABEL: @store_out_of_bounds(
+// STORE-SCATTER: vector.transfer_write
+}
// -----
-
-func.func @no_store_transposed(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
%source: memref<32x64xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset]
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
in_bounds = [true, true]}
: vector<8x16xf32>, memref<32x64xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @no_store_transposed(
-// CHECK: vector.transfer_write
+// STORE-ND-LABEL: @no_store_transposed(
+// STORE-ND: vector.transfer_write
+
+// STORE-SCATTER-LABEL: @no_store_transposed(
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<32x64xf32>,
+// STORE-SCATTER-SAME: %[[OFFSET:.+]]: index
+// STORE-SCATTER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// STORE-SCATTER-COUNT2: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT2: vector.shape_cast {{.*}}
+// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<32x64xf32> into memref<2048xf32>
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1>
+}
// -----
+gpu.module @xevm_module {
+gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>,
+ %source: memref<16x32x64xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ {in_bounds = [true, true, true]}
+ : vector<8x16x32xf32>, memref<16x32x64xf32>
+ gpu.return
+}
+
+// STORE-ND-LABEL: @store_high_dim_vector(
+// STORE-ND: vector.transfer_write
+
+// STORE-SCATTER-LABEL: @store_high_dim_vector(
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16x32xf32>,
+// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<16x32x64xf32>
+// STORE-SCATTER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// STORE-SCATTER: %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// STORE-SCATTER: %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// STORE-SCATTER: %[[C2048:.+]] = arith.constant 2048 : index
+// STORE-SCATTER: %[[C64:.+]] = arith.constant 64 : index
+// STORE-SCATTER-COUNT3: vector.step
+// STORE-SCATTER-COUNT3: vector.shape_cast
+// STORE-SCATTER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1>
+}
-func.func @no_store_masked(%vec: vector<4xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_store_masked(%vec: vector<4xf32>,
%source: memref<4xf32>, %offset: index) {
%mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
vector.transfer_write %vec, %source[%offset], %mask
{in_bounds = [true]}
: vector<4xf32>, memref<4xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @no_store_masked(
-// CHECK: vector.transfer_write
+// STORE-ND-LABEL: @no_store_masked(
+// STORE-ND: vector.transfer_write
-// -----
+// STORE-SCATTER-LABEL: @no_store_masked(
+// STORE-SCATTER: vector.transfer_write
+}
-func.func @no_store_tensor(%vec: vector<8x16xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_store_tensor(%vec: vector<8x16xf32>,
%source: tensor<32x64xf32>, %offset: index) -> tensor<32x64xf32> {
%0 = vector.transfer_write %vec, %source[%offset, %offset]
{in_bounds = [true, true]}
: vector<8x16xf32>, tensor<32x64xf32>
- return %0 : tensor<32x64xf32>
+ gpu.return %0 : tensor<32x64xf32>
}
-// CHECK-LABEL: @no_store_tensor(
-// CHECK: vector.transfer_write
-
-// -----
+// STORE-ND-LABEL: @no_store_tensor(
+// STORE-ND: vector.transfer_write
-func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
- %source: memref<16x32x64xf32>, %offset: index) {
- vector.transfer_write %vec, %source[%offset, %offset, %offset]
- {in_bounds = [true, true, true]}
- : vector<8x16x32xf32>, memref<16x32x64xf32>
- return
+// STORE-SCATTER-LABEL: @no_store_tensor(
+// STORE-SCATTER: vector.transfer_write
}
-// CHECK-LABEL: @no_store_high_dim_vector(
-// CHECK: vector.transfer_write
-
// -----
-
-func.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
%source: memref<32xf32, strided<[?], offset: ?>>, %offset: index) {
vector.transfer_write %vec, %source[%offset]
{in_bounds = [true]}
: vector<8xf32>, memref<32xf32, strided<[?], offset: ?>>
- return
+ gpu.return
}
-// CHECK-LABEL: @no_store_non_unit_inner_stride(
-// CHECK: vector.transfer_write
+// STORE-ND-LABEL: @no_store_non_unit_inner_stride(
+// STORE-ND: vector.transfer_write
-// -----
+// STORE-SCATTER-LABEL: @no_store_non_unit_inner_stride(
+// STORE-SCATTER: vector.transfer_write
+}
-func.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
%source: memref<16x32x64xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset, %offset]
{permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
in_bounds = [true, true]}
: vector<8x16xf32>, memref<16x32x64xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @no_store_unsupported_map(
-// CHECK: vector.transfer_write
+// STORE-ND-LABEL: @no_store_unsupported_map(
+// STORE-ND: vector.transfer_write
-// -----
+// STORE-SCATTER-LABEL: @no_store_unsupported_map(
+// STORE-SCATTER: vector.transfer_write
+}
-func.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
%source: memref<8x16x32xf32>, %offset: index) {
vector.transfer_write %vec, %source[%offset, %offset, %offset]
{in_bounds = [false]}
: vector<8xf32>, memref<8x16x32xf32>
- return
+ gpu.return
}
-// CHECK-LABEL: @no_store_out_of_bounds_1D_vector(
-// CHECK: vector.transfer_write
+// STORE-ND-LABEL: @no_store_out_of_bounds_1D_vector(
+// STORE-ND: vector.transfer_write
+
+// STORE-SCATTER-LABEL: @no_store_out_of_bounds_1D_vector(
+// STORE-SCATTER: vector.transfer_write
+} \ No newline at end of file
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
new file mode 100644
index 0000000..ed664a7
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @create_nd_tdesc {
+ // CHECK-LABEL: gpu.func @create_nd_tdesc
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
+ // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+ gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+ %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
+ // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
+ // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
+ // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
+ // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+ // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+ %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
+ : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+ %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
+
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+ // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
+ // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+ // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+ // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
+ // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
+ // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+ // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
+ // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
+ // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
+ %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ %c8 = arith.constant 8 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+ // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+ // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+ %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
new file mode 100644
index 0000000..e6f22f0
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+
+gpu.module @load_store_check {
+ // CHECK-LABEL: func.func @dpas(
+ // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
+ func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
+ // Loads are checked in a separate test.
+ // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
+ // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+ %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+ : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+ return %d : vector<8xf32>
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/fence.mlir b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir
new file mode 100644
index 0000000..cedfcac
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @fence_check {
+ gpu.func @fence(%dst: memref<8x16xf32, 1>) kernel {
+ %tid_x = gpu.thread_id x
+ %tid_x_i32 = arith.index_cast %tid_x : index to i32
+ %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+
+ // CHECK: xevm.memfence <{addrspace = #xevm.addr_space<global>, scope = #xevm.mem_scope<workgroup>}>
+ xegpu.fence memory_kind = global, fence_scope = workgroup
+ %c0 = arith.constant 0 : index
+ memref.store %tid_x_f32, %dst[%c0, %c0] : memref<8x16xf32, 1>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
new file mode 100644
index 0000000..4c6bbf2
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @load_store_check {
+ gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+ // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+ // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+
+ //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+ //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+ //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+ //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+ //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
+ //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
+ //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
+ //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
+ //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
+ //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+ //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+ //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+ //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
+
+ %tid_x = gpu.thread_id x
+ %tid_x_i32 = arith.index_cast %tid_x : index to i32
+ %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+ //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+ %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+ // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+ // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+ //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
+ //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
+ //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
+ //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
+ //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
+ //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
+ //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
+ //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
+ //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
+ //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
+ //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
+ //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
new file mode 100644
index 0000000..0f67dc2
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -0,0 +1,261 @@
+// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
+ // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32>
+ // CHECK: scf.yield %[[VAR9]] : vector<2xf32>
+ // CHECK: } else {
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+ // CHECK: scf.yield %[[CST_1]] : vector<2xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32>
+ // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
+ // CHECK: scf.yield %[[VAR9]] : f32
+ // CHECK: } else {
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+ // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32>
+ // CHECK: scf.yield %[[VAR8]] : f32
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) {
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16>
+ // CHECK: scf.yield %[[VAR8]] : vector<8xf16>
+ // CHECK: } else {
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
+ // CHECK: scf.yield %[[CST_0]] : vector<8xf16>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xi1> -> vector<8xf16>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+ %2 = arith.constant dense<2.9>: vector<2xf32>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR4]] {
+ // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
+ %2 = arith.constant dense<2.9>: vector<2xf16>
+ // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR2]] {
+ // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+ // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
+ %2 = arith.constant dense<2.9>: vector<1xf32>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR2]] {
+ // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : f32, !llvm.ptr<1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+ %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+ %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
new file mode 100644
index 0000000..b28a8c2
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s
+
+// This file contains tests for materalization patterns added to handle custom type conversions
+// added on top of LLVM type converter.
+
+gpu.module @materializecast {
+ // CHECK-LABEL: gpu.func @materialize_memref
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+ // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+}
+
+// -----
+gpu.module @materializecast {
+ // CHECK-LABEL: gpu.func @materialize_ui64
+ // CHECK-SAME: %[[ARG0:.*]]: ui64
+ gpu.func @materialize_ui64(%src: ui64) kernel {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+}
+
+// -----
+gpu.module @materializecast {
+ // CHECK-LABEL: gpu.func @materialize_ui32
+ // CHECK-SAME: %[[ARG0:.*]]: ui32
+ gpu.func @materialize_ui32(%src: ui32) kernel {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+}
+
+// -----
+gpu.module @materializecast {
+ // CHECK-LABEL: gpu.func @materialize_single_index_vector
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+}
+
+// -----
+gpu.module @materializecast {
+ // CHECK-LABEL: gpu.func @materialize_single_elem_vector
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
new file mode 100644
index 0000000..873478a
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
+
+gpu.module @fence_check {
+ gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+ // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+ // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
+ #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+
+ //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+ //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+ //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+ //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+ //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
+ //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
+ //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
+ //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
+ //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
+ //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
+ //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
+ //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+
+ gpu.return
+ }
+}
+
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
new file mode 100644
index 0000000..6e59414
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @update_offset {
+ // CHECK-LABEL: gpu.func @update_offset
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @update_offset(%src: memref<128xf32>) kernel {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ %offset = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+ %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
+ %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ , vector<1xindex>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 57afa12..8ca3dd6 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
// CHECK: func @test_expand_shape
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
- // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+ // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
- // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
- // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+ // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+ // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+ // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+ // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
- %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+ %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<8192xf16>
- %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+ %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+ %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
%c0 = arith.constant 0 : index
- amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+ amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
func.return
}
@@ -80,15 +82,82 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
- // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+ // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+ // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+ // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
// CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
%alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+ %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<64x128xf16>
- %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+ %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
%c0 = arith.constant 0 : index
- amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
+ amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+ : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
+ func.return
+}
+
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+
+// CHECK: func @test_expand_shape_src_raw_buffer
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_expand_shape_src_raw_buffer(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+ // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG1]], %[[ARG2]]] by (64, 128) : index
+ // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[IDXM]]], %[[LOCAL]][%[[C0]]]
+ // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+ %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+ %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>> into memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>
+
+ %c0 = arith.constant 0 : index
+ amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %alloc[%c0]
+ : vector<8xf16>, memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+ func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape_dst_only
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape_dst_only(%offset_i: index, %offset_j: index) {
+ // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+ // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[IDX_LDS:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (64, 64) : index
+ // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]]], %[[LOCAL]][%[[IDX_LDS]]]
+ // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
+
+ %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+ %mem = memref.alloc() : memref<8192xf16>
+ %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
+
+ %c0 = arith.constant 0 : index
+ amdgpu.gather_to_lds %mem[%offset_i], %expand_alloc[%offset_j, %c0]
: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
func.return
}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_nop
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_nop(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+ // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+ // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[ARG1]]], %[[LOCAL]][%[[ARG2]]]
+ // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+ %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %mem[%offset_i], %alloc[%offset_j]
+ : vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 0d2fd24..66e7dd4 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -230,3 +230,11 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16>
func.return
}
+
+// -----
+
+func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : memref<32xf16, strided<[?]>, #gpu.address_space<workgroup>>) {
+ // expected-error@+1 {{'amdgpu.gather_to_lds' op destination type inner most dim must be contiguous}}
+ amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16, strided<[?]>, #gpu.address_space<workgroup>>
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index fe78b53..369e0ff 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
func.return %0 : f32
}
+// CHECK-LABEL: func @permlane16_swap
+func.func @permlane16_swap(%arg0 : f32) -> f32 {
+ // CHECK: amdgpu.permlane_swap
+ %0 = amdgpu.permlane_swap %arg0 16 : f32
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: func @permlane32_swap
+func.func @permlane32_swap(%arg0 : f32) -> f32 {
+ // CHECK: amdgpu.permlane_swap
+ %0 = amdgpu.permlane_swap %arg0 32 : f32
+ func.return %0 : f32
+}
+
// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma
@@ -539,13 +553,15 @@ func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16
}
// CHECK-LABEL: func @gather_to_lds
-func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>) {
+func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %mem2 : memref<32x32xf16>, %smem1 : memref<32xf16, #gpu.address_space<workgroup>>, %smem2 : memref<32x32xf16, #gpu.address_space<workgroup>>, %smem3 : memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>) {
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}[%{{.*}}]
// CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
+ // CHECK: amdgpu.gather_to_lds %{{.*}}[%{{.*}}], %{{.*}}[%{{.*}}, %{{.*}}]
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32x32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds %mem2[%idx1, %idx2], %smem1[%idx1] : vector<2xf16>, memref<32x32xf16>, memref<32xf16, #gpu.address_space<workgroup>>
amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space<workgroup>>
+ amdgpu.gather_to_lds %mem1[%idx1], %smem3[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<?x?xf16, strided<[?, 1]>, #gpu.address_space<workgroup>>
func.return
}
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index 8febe16..5de9b3f 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -1,48 +1,158 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-// -----
-
-func.func @rowheight() {
+func.func @tile_row_height() {
// expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
%0 = amx.tile_zero : !amx.tile<17x16xbf16>
+ return
}
// -----
-func.func @colwidth() {
+func.func @tile_col_width() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
%0 = amx.tile_zero : !amx.tile<16x65xi8>
+ return
+}
+
+// -----
+
+func.func @tile_element_type() {
+ // expected-error@+1 {{failed to verify 'elementType'}}
+ %0 = amx.tile_zero : !amx.tile<8x8xi16>
+ return
+}
+
+// -----
+
+func.func @tile_rank() {
+ // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
+ %0 = amx.tile_zero : !amx.tile<32xi8>
+ return
}
// -----
-func.func @col4bytemultiple() {
+func.func @tile_col_4_byte_multiple() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : !amx.tile<16x5xi8>
+ return
}
// -----
-func.func @memtilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
+ return
}
// -----
-func.func @memindexsize(%arg0: memref<?x?xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+ %0 = arith.constant 0 : index
+ // expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
+ amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
+ return
+}
+
+// -----
+
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
+ return
}
// -----
-func.func @multsize() {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+ %0 = arith.constant 0 : index
+ // expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
+ amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
+ return
+}
+
+// -----
+
+func.func @load_base_rank(%arg0: memref<?xf32>) {
+ %0 = arith.constant 0 : index
+ // expected-error@+1 {{'amx.tile_load' op requires at least 2D memref}}
+ %1 = amx.tile_load %arg0[%0] : memref<?xf32> into !amx.tile<16x16xf32>
+ return
+}
+
+// -----
+
+func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !amx.tile<16x16xf32>) {
+ %0 = arith.constant 0 : index
+ // expected-error@+1 {{'amx.tile_store' op requires at least 2D memref}}
+ amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !amx.tile<16x16xf32>
+ return
+}
+
+// -----
+
+func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
+ %0 = arith.constant 0 : index
+ // expected-error@+1 {{'amx.tile_load' op requires memref with unit innermost stride}}
+ %1 = amx.tile_load %arg0[%0, %0]
+ : memref<?x?xf32, strided<[?, ?]>> into !amx.tile<16x16xf32>
+ return
+}
+
+// -----
+
+func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
+ %arg1: !amx.tile<16x16xf32>) {
+ %0 = arith.constant 0 : index
+ // expected-error@+1 {{'amx.tile_store' op requires memref with unit innermost stride}}
+ amx.tile_store %arg0[%0, %0], %arg1
+ : memref<?x?xf32, strided<[?, ?]>>, !amx.tile<16x16xf32>
+ return
+}
+
+// -----
+
+func.func @mulf_shape() {
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
%1 = amx.tile_zero : !amx.tile<8x8xbf16>
%2 = amx.tile_zero : !amx.tile<4x4xf32>
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
+ return
+}
+
+// -----
+
+func.func @mulf_type_combination() {
+ %0 = amx.tile_zero : !amx.tile<8x8xbf16>
+ %1 = amx.tile_zero : !amx.tile<4x8xf16>
+ %2 = amx.tile_zero : !amx.tile<8x4xf32>
+ // expected-error@+1 {{'amx.tile_mulf' op unsupported type combination}}
+ %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<4x8xf16>, !amx.tile<8x4xf32>
+ return
+}
+
+// -----
+
+func.func @muli_shape() {
+ %0 = amx.tile_zero : !amx.tile<8x8xi8>
+ %1 = amx.tile_zero : !amx.tile<8x8xi8>
+ %2 = amx.tile_zero : !amx.tile<4x4xi32>
+ // expected-error@+1 {{'amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
+ %3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x8xi8>, !amx.tile<8x8xi8>, !amx.tile<4x4xi32>
+ return
+}
+
+// -----
+
+func.func @muli_type_combination() {
+ %0 = amx.tile_zero : !amx.tile<8x16xi8>
+ %1 = amx.tile_zero : !amx.tile<8x16xi32>
+ %2 = amx.tile_zero : !amx.tile<2x2xi32>
+ // expected-error@+1 {{'amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
+ %3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x16xi8>, !amx.tile<8x16xi32>, !amx.tile<2x2xi32>
+ return
}
diff --git a/mlir/test/Dialect/AMX/side-effects.mlir b/mlir/test/Dialect/AMX/side-effects.mlir
new file mode 100644
index 0000000..22c76d9
--- /dev/null
+++ b/mlir/test/Dialect/AMX/side-effects.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-amx" | FileCheck %s
+
+// With inclusion of memory side-effects, it is expected CSE not to fold multiple
+// "tileload" and "tilezero".
+// CHECK-LABEL: do_not_fold_tiles(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c16 = arith.constant 16 : index
+ %alloca = memref.alloca() : memref<16x32xf32>
+ %0 = amx.tile_zero : !amx.tile<16x16xf32>
+ %1 = amx.tile_zero : !amx.tile<16x16xf32>
+ %2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>) {
+ %3 = amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
+ %4 = amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16>
+ %5 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
+ %6 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16>
+ %7 = amx.tile_mulf %3, %5, %arg3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+ %8 = amx.tile_mulf %4, %6, %arg4 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+ scf.yield %7, %8 : !amx.tile<16x16xf32>, !amx.tile<16x16xf32>
+ }
+ amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !amx.tile<16x16xf32>
+ amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !amx.tile<16x16xf32>
+ return %alloca : memref<16x32xf32>
+}
diff --git a/mlir/test/Dialect/Affine/loop-permute.mlir b/mlir/test/Dialect/Affine/loop-permute.mlir
index 118165b..e38aeb5 100644
--- a/mlir/test/Dialect/Affine/loop-permute.mlir
+++ b/mlir/test/Dialect/Affine/loop-permute.mlir
@@ -4,6 +4,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=0,2,1" | FileCheck %s --check-prefix=CHECK-021
// RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=2,0,1" | FileCheck %s --check-prefix=CHECK-201
// RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=2,1,0" | FileCheck %s --check-prefix=CHECK-210
+// RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=2,1,0 check-validity=1" | FileCheck %s --check-prefix=CHECK-210-VALID
// CHECK-120-LABEL: func @permute
func.func @permute(%U0 : index, %U1 : index, %U2 : index) {
@@ -45,3 +46,34 @@ func.func @permute(%U0 : index, %U1 : index, %U2 : index) {
// CHECK-201: "foo"(%arg5, %arg3)
// CHECK-201-NEXT: "bar"(%arg4)
+
+// -----
+
+// Tests that the permutation validation check utility conservatively returns false when the
+// for loop has an iter_arg.
+
+// CHECK-210-VALID-LABEL: func @check_validity_with_iter_args
+// CHECK-210-VALID-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @check_validity_with_iter_args(%U0 : index, %U1 : index, %U2 : index) {
+ %buf = memref.alloc() : memref<100x100xf32>
+ %cst = arith.constant 1.0 : f32
+ %c10 = arith.constant 10 : index
+ %c20 = arith.constant 20 : index
+
+ // Check that the loops are not permuted.
+ // CHECK-210-VALID: affine.for %{{.*}} = 0 to %[[ARG0]] {
+ // CHECK-210-VALID-NEXT: affine.for %{{.*}} = 0 to %[[ARG1]] {
+ // CHECK-210-VALID-NEXT: affine.for %{{.*}} = 0 to %[[ARG2]] iter_args(
+ affine.for %arg0 = 0 to %U0 {
+ affine.for %arg1 = 0 to %U1 {
+ %res = affine.for %arg2 = 0 to %U2 iter_args(%iter1 = %cst) -> (f32) {
+ %val = affine.load %buf[%arg0 + 10, %arg1 + 20] : memref<100x100xf32>
+ %newVal = arith.addf %val, %cst : f32
+ affine.store %newVal, %buf[%arg0 + 10, %arg1 + 20] : memref<100x100xf32>
+ %newVal2 = arith.addf %newVal, %iter1 : f32
+ affine.yield %iter1 : f32
+ }
+ }
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 3d5a46d..ca3de3a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -654,7 +654,7 @@ func.func @signExtendConstant() -> i16 {
// CHECK: return %[[cres]]
func.func @signExtendConstantSplat() -> vector<4xi16> {
%c-2 = arith.constant -2 : i8
- %splat = vector.splat %c-2 : vector<4xi8>
+ %splat = vector.broadcast %c-2 : i8 to vector<4xi8>
%ext = arith.extsi %splat : vector<4xi8> to vector<4xi16>
return %ext : vector<4xi16>
}
@@ -682,7 +682,7 @@ func.func @unsignedExtendConstant() -> i16 {
// CHECK: return %[[cres]]
func.func @unsignedExtendConstantSplat() -> vector<4xi16> {
%c2 = arith.constant 2 : i8
- %splat = vector.splat %c2 : vector<4xi8>
+ %splat = vector.broadcast %c2 : i8 to vector<4xi8>
%ext = arith.extui %splat : vector<4xi8> to vector<4xi16>
return %ext : vector<4xi16>
}
@@ -866,7 +866,7 @@ func.func @truncExtsiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
// CHECK: return %[[cres]]
func.func @truncConstantSplat() -> vector<4xi8> {
%c-2 = arith.constant -2 : i16
- %splat = vector.splat %c-2 : vector<4xi16>
+ %splat = vector.broadcast %c-2 : i16 to vector<4xi16>
%trunc = arith.trunci %splat : vector<4xi16> to vector<4xi8>
return %trunc : vector<4xi8>
}
@@ -2334,7 +2334,7 @@ func.func @constant_FPtoUI_splat() -> vector<4xi32> {
// CHECK: %[[C0:.+]] = arith.constant dense<2> : vector<4xi32>
// CHECK: return %[[C0]]
%c0 = arith.constant 2.0 : f32
- %splat = vector.splat %c0 : vector<4xf32>
+ %splat = vector.broadcast %c0 : f32 to vector<4xf32>
%res = arith.fptoui %splat : vector<4xf32> to vector<4xi32>
return %res : vector<4xi32>
}
@@ -2374,7 +2374,7 @@ func.func @constant_FPtoSI_splat() -> vector<4xi32> {
// CHECK: %[[C0:.+]] = arith.constant dense<-2> : vector<4xi32>
// CHECK: return %[[C0]]
%c0 = arith.constant -2.0 : f32
- %splat = vector.splat %c0 : vector<4xf32>
+ %splat = vector.broadcast %c0 : f32 to vector<4xf32>
%res = arith.fptosi %splat : vector<4xf32> to vector<4xi32>
return %res : vector<4xi32>
}
@@ -2413,7 +2413,7 @@ func.func @constant_SItoFP_splat() -> vector<4xf32> {
// CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK: return %[[C0]]
%c0 = arith.constant 2 : i32
- %splat = vector.splat %c0 : vector<4xi32>
+ %splat = vector.broadcast %c0 : i32 to vector<4xi32>
%res = arith.sitofp %splat : vector<4xi32> to vector<4xf32>
return %res : vector<4xf32>
}
@@ -2442,7 +2442,7 @@ func.func @constant_UItoFP_splat() -> vector<4xf32> {
// CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK: return %[[C0]]
%c0 = arith.constant 2 : i32
- %splat = vector.splat %c0 : vector<4xi32>
+ %splat = vector.broadcast %c0 : i32 to vector<4xi32>
%res = arith.uitofp %splat : vector<4xi32> to vector<4xf32>
return %res : vector<4xf32>
}
@@ -3363,3 +3363,18 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>,
}
}
#-}
+
+// CHECK-LABEL: func @unreachable()
+// CHECK-NEXT: return
+// CHECK-NOT: arith
+func.func @unreachable() {
+ return
+^unreachable:
+ %c1_i64 = arith.constant 1 : i64
+ // This self referencing operation is legal in an unreachable block.
+ // Many patterns are unsafe with respect to this kind of situation,
+ // check that we don't infinite loop here.
+ %add = arith.addi %add, %c1_i64 : i64
+ cf.br ^unreachable
+}
+
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 2128d36..130782ba 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -224,6 +224,15 @@ func.func @ceil_divui(%arg0 : index) -> i1 {
func.return %7 : i1
}
+// CHECK-LABEL: func @ceil_divui_by_zero_issue_131273
+// CHECK-NEXT: return
+func.func @ceil_divui_by_zero_issue_131273() {
+ %0 = test.with_bounds {smax = 0 : i32, smin = -1 : i32, umax = 0 : i32, umin = -1 : i32} : i32
+ %c7_i32 = arith.constant 7 : i32
+ %1 = arith.ceildivui %c7_i32, %0 : i32
+ return
+}
+
// CHECK-LABEL: func @ceil_divsi
// CHECK: %[[ret:.*]] = arith.cmpi eq
// CHECK: return %[[ret]]
diff --git a/mlir/test/Dialect/EmitC/attrs.mlir b/mlir/test/Dialect/EmitC/attrs.mlir
index 11251b8..5a219c4 100644
--- a/mlir/test/Dialect/EmitC/attrs.mlir
+++ b/mlir/test/Dialect/EmitC/attrs.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt %s | FileCheck %s
// check parser
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK-LABEL: func @opaque_attrs() {
func.func @opaque_attrs() {
diff --git a/mlir/test/Dialect/EmitC/transforms.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir
index a38f396..67cd6fd 100644
--- a/mlir/test/Dialect/EmitC/transforms.mlir
+++ b/mlir/test/Dialect/EmitC/form-expressions.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s --form-expressions --verify-diagnostics --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -form-expressions | FileCheck %s
// CHECK-LABEL: func.func @single_expression(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 {
// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32
-// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_5:.*]] = emitc.expression %[[VAL_3]], %[[VAL_2]], %[[VAL_0]], %[[VAL_4]] : (i32, i32, i32, i32) -> i1 {
// CHECK: %[[VAL_6:.*]] = mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32
// CHECK: %[[VAL_7:.*]] = sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32
// CHECK: %[[VAL_8:.*]] = cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1
@@ -22,12 +22,12 @@ func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
// CHECK-LABEL: func.func @multiple_expressions(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
-// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
// CHECK: %[[VAL_5:.*]] = mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
// CHECK: %[[VAL_6:.*]] = sub %[[VAL_5]], %[[VAL_2]] : (i32, i32) -> i32
// CHECK: yield %[[VAL_6]] : i32
// CHECK: }
-// CHECK: %[[VAL_7:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_2]], %[[VAL_1]], %[[VAL_3]] : (i32, i32, i32) -> i32 {
// CHECK: %[[VAL_8:.*]] = add %[[VAL_1]], %[[VAL_3]] : (i32, i32) -> i32
// CHECK: %[[VAL_9:.*]] = div %[[VAL_8]], %[[VAL_2]] : (i32, i32) -> i32
// CHECK: yield %[[VAL_9]] : i32
@@ -45,12 +45,12 @@ func.func @multiple_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
// CHECK-LABEL: func.func @expression_with_call(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 {
-// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
// CHECK: %[[VAL_5:.*]] = mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
// CHECK: %[[VAL_6:.*]] = call_opaque "foo"(%[[VAL_5]], %[[VAL_2]]) : (i32, i32) -> i32
// CHECK: yield %[[VAL_6]] : i32
// CHECK: }
-// CHECK: %[[VAL_7:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_4]], %[[VAL_1]] : (i32, i32) -> i1 {
// CHECK: %[[VAL_8:.*]] = cmp lt, %[[VAL_4]], %[[VAL_1]] : (i32, i32) -> i1
// CHECK: yield %[[VAL_8]] : i1
// CHECK: }
@@ -66,11 +66,11 @@ func.func @expression_with_call(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
// CHECK-LABEL: func.func @expression_with_dereference(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr<i32>) -> i1 {
-// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_3:.*]] = emitc.expression %[[VAL_2]] : (!emitc.ptr<i32>) -> i32 {
// CHECK: %[[VAL_4:.*]] = apply "*"(%[[VAL_2]]) : (!emitc.ptr<i32>) -> i32
// CHECK: yield %[[VAL_4]] : i32
// CHECK: }
-// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_5:.*]] = emitc.expression %[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i1 {
// CHECK: %[[VAL_6:.*]] = mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
// CHECK: %[[VAL_7:.*]] = cmp lt, %[[VAL_6]], %[[VAL_3]] : (i32, i32) -> i1
// CHECK: return %[[VAL_5]] : i1
@@ -83,11 +83,10 @@ func.func @expression_with_dereference(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr
return %c : i1
}
-
// CHECK-LABEL: func.func @expression_with_address_taken(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr<i32>) -> i1 {
// CHECK: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
-// CHECK: %[[VAL_4:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_1]], %[[VAL_3]] : (!emitc.ptr<i32>, i32, !emitc.lvalue<i32>) -> i1 {
// CHECK: %[[VAL_5:.*]] = apply "&"(%[[VAL_3]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK: %[[VAL_6:.*]] = add %[[VAL_5]], %[[VAL_1]] : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
// CHECK: %[[VAL_7:.*]] = cmp lt, %[[VAL_6]], %[[VAL_2]] : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
@@ -106,7 +105,7 @@ func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.p
// CHECK-LABEL: func.func @no_nested_expression(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32) -> i1 {
-// CHECK: %[[VAL_2:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_2:.*]] = emitc.expression %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i1 {
// CHECK: %[[VAL_3:.*]] = cmp lt, %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i1
// CHECK: yield %[[VAL_3]] : i1
// CHECK: }
@@ -114,16 +113,15 @@ func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.p
// CHECK: }
func.func @no_nested_expression(%arg0: i32, %arg1: i32) -> i1 {
- %a = emitc.expression : i1 {
+ %a = emitc.expression %arg0, %arg1 :(i32, i32) -> i1 {
%b = emitc.cmp lt, %arg0, %arg1 :(i32, i32) -> i1
emitc.yield %b : i1
}
return %a : i1
}
-
// CHECK-LABEL: func.func @single_result_requirement
-// CHECK-NOT: emitc.expression
+// CHECK-NOT: emitc.expression
func.func @single_result_requirement() -> (i32, i32) {
%0:2 = emitc.call_opaque "foo" () : () -> (i32, i32)
@@ -135,16 +133,16 @@ func.func @single_result_requirement() -> (i32, i32) {
// CHECK-SAME: %[[VAL_1:.*]]: !emitc.ptr<i32>) -> i1 {
// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64
// CHECK: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32>
-// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_3]] : (!emitc.lvalue<i32>) -> i32 {
// CHECK: %[[VAL_5:.*]] = load %[[VAL_3]] : <i32>
// CHECK: yield %[[VAL_5]] : i32
// CHECK: }
// CHECK: %[[VAL_6:.*]] = emitc.subscript %[[VAL_1]]{{\[}}%[[VAL_2]]] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32>
-// CHECK: %[[VAL_7:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_6]] : (!emitc.lvalue<i32>) -> i32 {
// CHECK: %[[VAL_8:.*]] = load %[[VAL_6]] : <i32>
// CHECK: yield %[[VAL_8]] : i32
// CHECK: }
-// CHECK: %[[VAL_9:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_9:.*]] = emitc.expression %[[VAL_0]], %[[VAL_4]], %[[VAL_7]] : (i32, i32, i32) -> i1 {
// CHECK: %[[VAL_10:.*]] = add %[[VAL_4]], %[[VAL_7]] : (i32, i32) -> i32
// CHECK: %[[VAL_11:.*]] = cmp lt, %[[VAL_10]], %[[VAL_0]] : (i32, i32) -> i1
// CHECK: yield %[[VAL_11]] : i1
@@ -152,7 +150,6 @@ func.func @single_result_requirement() -> (i32, i32) {
// CHECK: return %[[VAL_9]] : i1
// CHECK: }
-
func.func @expression_with_load(%arg0: i32, %arg1: !emitc.ptr<i32>) -> i1 {
%c0 = "emitc.constant"() {value = 0 : i64} : () -> i64
%0 = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32>
@@ -163,3 +160,26 @@ func.func @expression_with_load(%arg0: i32, %arg1: !emitc.ptr<i32>) -> i1 {
%c = emitc.cmp lt, %b, %arg0 :(i32, i32) -> i1
return %c : i1
}
+
+// CHECK-LABEL: func.func @opaque_type_expression(%arg0: i32, %arg1: !emitc.opaque<"T0">, %arg2: i32) -> i1 {
+// CHECK: %0 = "emitc.constant"() <{value = 42 : i32}> : () -> i32
+// CHECK: %1 = emitc.expression %arg1, %arg0, %0 : (!emitc.opaque<"T0">, i32, i32) -> i32 {
+// CHECK: %3 = mul %arg0, %0 : (i32, i32) -> i32
+// CHECK: %4 = sub %3, %arg1 : (i32, !emitc.opaque<"T0">) -> i32
+// CHECK: yield %4 : i32
+// CHECK: }
+// CHECK: %2 = emitc.expression %1, %arg2 : (i32, i32) -> i1 {
+// CHECK: %3 = cmp lt, %1, %arg2 : (i32, i32) -> i1
+// CHECK: yield %3 : i1
+// CHECK: }
+// CHECK: return %2 : i1
+// CHECK: }
+
+
+func.func @opaque_type_expression(%arg0: i32, %arg1: !emitc.opaque<"T0">, %arg2: i32) -> i1 {
+ %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32
+ %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32
+ %b = emitc.sub %a, %arg1 : (i32, !emitc.opaque<"T0">) -> i32
+ %c = emitc.cmp lt, %b, %arg2 :(i32, i32) -> i1
+ return %c : i1
+}
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 3946a36..fdfb0eb 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -290,7 +290,7 @@ func.func @test_assign_to_array(%arg1: !emitc.array<4xi32>) {
func.func @test_expression_no_yield() -> i32 {
// expected-error @+1 {{'emitc.expression' op must yield a value at termination}}
- %r = emitc.expression : i32 {
+ %r = emitc.expression : () -> i32 {
%c7 = "emitc.constant"(){value = 7 : i32} : () -> i32
}
return %r : i32
@@ -300,7 +300,7 @@ func.func @test_expression_no_yield() -> i32 {
func.func @test_expression_illegal_op(%arg0 : i1) -> i32 {
// expected-error @+1 {{'emitc.expression' op contains an unsupported operation}}
- %r = emitc.expression : i32 {
+ %r = emitc.expression : () -> i32 {
%x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
%y = emitc.load %x : <i32>
emitc.yield %y : i32
@@ -312,7 +312,7 @@ func.func @test_expression_illegal_op(%arg0 : i1) -> i32 {
func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 {
// expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}}
- %r = emitc.expression : i32 {
+ %r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
emitc.yield %a : i32
@@ -324,7 +324,7 @@ func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 {
func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
// expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}}
- %r = emitc.expression : i32 {
+ %r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 {
%a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.add %a, %arg0 : (i32, i32) -> i32
%c = emitc.mul %arg1, %a : (i32, i32) -> i32
@@ -337,7 +337,7 @@ func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
func.func @test_expression_multiple_results(%arg0: i32) -> i32 {
// expected-error @+1 {{'emitc.expression' op requires exactly one result for each operation}}
- %r = emitc.expression : i32 {
+ %r = emitc.expression %arg0 : (i32) -> i32 {
%a:2 = emitc.call_opaque "bar" (%arg0) : (i32) -> (i32, i32)
emitc.yield %a : i32
}
@@ -348,7 +348,7 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 {
emitc.func @test_expression_no_defining_op(%a : i32) {
// expected-error @+1 {{'emitc.expression' op yielded value has no defining op}}
- %res = emitc.expression : i32 {
+ %res = emitc.expression %a : (i32) -> i32 {
emitc.yield %a : i32
}
@@ -357,10 +357,21 @@ emitc.func @test_expression_no_defining_op(%a : i32) {
// -----
+emitc.func @test_expression_no_defining_op() {
+ %cond = literal "true" : i1
+ // expected-error @+1 {{'emitc.expression' op yielded value has no defining op}}
+ %res = emitc.expression %cond : (i1) -> i1 {
+ emitc.yield %cond : i1
+ }
+ return
+}
+
+// -----
+
emitc.func @test_expression_op_outside_expression() {
%cond = literal "true" : i1
- // expected-error @+1 {{'emitc.expression' op yielded value not defined within expression}}
- %res = emitc.expression : i1 {
+ %res = emitc.expression : () -> i1 {
+ // expected-error @+1 {{use of undeclared SSA value name}}
emitc.yield %cond : i1
}
return
@@ -676,3 +687,35 @@ func.func @test_verbatim(%arg0 : !emitc.ptr<i32>, %arg1 : i32) {
emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr<i32>, i32
return
}
+
+// -----
+
+// expected-error @+1 {{'emitc.field' op field must be nested within an emitc.class operation}}
+emitc.field @testField : !emitc.array<1xf32>
+
+// -----
+
+// expected-error @+1 {{'emitc.get_field' op must be nested within an emitc.class operation}}
+%1 = emitc.get_field @testField : !emitc.array<1xf32>
+
+// -----
+
+emitc.func @testMethod() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ // expected-error @+1 {{'emitc.get_field' op must be nested within an emitc.class operation}}
+ %1 = get_field @testField : !emitc.array<1xf32>
+ %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+}
+
+// -----
+
+emitc.class @testClass {
+ emitc.func @testMethod() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ // expected-error @+1 {{'emitc.get_field' op field '@testField' not found in the class}}
+ %1 = get_field @testField : !emitc.array<1xf32>
+ %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index ad40313..e890f77 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -188,11 +188,11 @@ func.func @test_assign(%arg1: f32) {
func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> i32 {
%c7 = "emitc.constant"() {value = 7 : i32} : () -> i32
- %q = emitc.expression : i32 {
+ %q = emitc.expression %arg1, %c7 : (i32, i32) -> i32 {
%a = emitc.rem %arg1, %c7 : (i32, i32) -> i32
emitc.yield %a : i32
}
- %r = emitc.expression noinline : i32 {
+ %r = emitc.expression %arg0, %arg1, %arg2, %arg3, %arg4, %q noinline : (i32, i32, i32, f32, f32, i32) -> i32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.call_opaque "bar" (%a, %arg2, %q) : (i32, i32, i32) -> (i32)
%c = emitc.mul %arg3, %arg4 : (f32, f32) -> f32
@@ -286,8 +286,11 @@ func.func @assign_global(%arg0 : i32) {
func.func @member_access(%arg0: !emitc.lvalue<!emitc.opaque<"mystruct">>, %arg1: !emitc.lvalue<!emitc.opaque<"mystruct_ptr">>, %arg2: !emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>) {
%0 = "emitc.member" (%arg0) {member = "a"} : (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.lvalue<i32>
- %1 = "emitc.member_of_ptr" (%arg1) {member = "a"} : (!emitc.lvalue<!emitc.opaque<"mystruct_ptr">>) -> !emitc.lvalue<i32>
- %2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>) -> !emitc.lvalue<i32>
+ %1 = "emitc.member" (%arg0) {member = "b"} : (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.array<2xi32>
+ %2 = "emitc.member_of_ptr" (%arg1) {member = "a"} : (!emitc.lvalue<!emitc.opaque<"mystruct_ptr">>) -> !emitc.lvalue<i32>
+ %3 = "emitc.member_of_ptr" (%arg1) {member = "b"} : (!emitc.lvalue<!emitc.opaque<"mystruct_ptr">>) -> !emitc.array<2xi32>
+ %4 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>) -> !emitc.lvalue<i32>
+ %5 = "emitc.member_of_ptr" (%arg2) {member = "b"} : (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>) -> !emitc.array<2xi32>
return
}
@@ -310,3 +313,15 @@ func.func @switch() {
return
}
+
+emitc.class final @finalClass {
+ emitc.field @fieldName0 : !emitc.array<1xf32>
+ emitc.field @fieldName1 : !emitc.array<1xf32>
+ emitc.func @execute() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = get_field @fieldName0 : !emitc.array<1xf32>
+ %2 = get_field @fieldName1 : !emitc.array<1xf32>
+ %3 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+}
diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir
index d4dd944..ce1e03a 100644
--- a/mlir/test/Dialect/EmitC/types.mlir
+++ b/mlir/test/Dialect/EmitC/types.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
-// check parser
-// RUN: mlir-opt -verify-diagnostics -allow-unregistered-dialect %s | mlir-opt -verify-diagnostics --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect | FileCheck %s
+// Check parser
+// RUN: mlir-opt %s -allow-unregistered-dialect | mlir-opt -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @array_types(
func.func @array_types(
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
new file mode 100644
index 0000000..809febd
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
+
+emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
+ emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
+ emitc.return
+}
+
+// CHECK: module {
+// CHECK: emitc.class @fooClass {
+// CHECK: emitc.field @fieldName0 : !emitc.array<1xf32>
+// CHECK: emitc.func @execute() {
+// CHECK: %0 = get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
+// CHECK: return
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+module attributes { } {
+ emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
+ %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
+ %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %2 = load %1 : <f32>
+ %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %4 = load %3 : <f32>
+ %5 = add %2, %4 : (f32, f32) -> f32
+ %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %5 : f32 to %6 : <f32>
+ return
+ }
+}
+
+// CHECK: module {
+// CHECK: emitc.class @modelClass {
+// CHECK: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
+// CHECK: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
+// CHECK: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
+// CHECK: emitc.func @execute() {
+// CHECK: get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK: get_field @fieldName1 : !emitc.array<1xf32>
+// CHECK: get_field @fieldName2 : !emitc.array<1xf32>
+// CHECK: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: load {{.*}} : <f32>
+// CHECK: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: load {{.*}} : <f32>
+// CHECK: add {{.*}}, {{.*}} : (f32, f32) -> f32
+// CHECK: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: assign {{.*}} : f32 to {{.*}} : <f32>
+// CHECK: return
+// CHECK: }
+// CHECK: }
+// CHECK: }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
deleted file mode 100644
index 029fa78..0000000
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ /dev/null
@@ -1,40 +0,0 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
-
-module attributes { } {
- emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
- %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
- %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
- %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
- %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %2 = load %1 : <f32>
- %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %4 = load %3 : <f32>
- %5 = add %2, %4 : (f32, f32) -> f32
- %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- assign %5 : f32 to %6 : <f32>
- return
- }
-}
-
-
-// CHECK: module {
-// CHECK-NEXT: emitc.class @modelClass {
-// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
-// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
-// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
-// CHECK-NEXT: emitc.func @execute() {
-// CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32>
-// CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32>
-// CHECK-NEXT: get_field @fieldName2 : !emitc.array<1xf32>
-// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT: load {{.*}} : <f32>
-// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT: load {{.*}} : <f32>
-// CHECK-NEXT: add {{.*}}, {{.*}} : (f32, f32) -> f32
-// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT: assign {{.*}} : f32 to {{.*}} : <f32>
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
deleted file mode 100644
index 92ed20c..0000000
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
-
-emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
- emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
- emitc.return
-}
-
-// CHECK: module {
-// CHECK-NEXT: emitc.class @fooClass {
-// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32>
-// CHECK-NEXT: emitc.func @execute() {
-// CHECK-NEXT: %0 = get_field @fieldName0 : !emitc.array<1xf32>
-// CHECK-NEXT: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
new file mode 100644
index 0000000..ea32d62
--- /dev/null
+++ b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
+
+func.func private @side_effect(%arg0 : f32, %arg1 : f32, %arg2 : f32)
+
+// CHECK-LABEL: func @broadcast_hoisting
+// CHECK-SAME: (%[[ARG:.*]]: f32, %[[IDX:.*]]: i32, {{.*}}: index)
+func.func @broadcast_hoisting(%arg0 : f32, %arg1 : i32, %arg2 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+// `any_lane` and `specific_lane` can be speculated across the control flow, but
+// `first_active_lane` cannot as active lanes can change.
+// CHECK: %[[V1:.*]] = gpu.subgroup_broadcast %[[ARG]], any_lane : f32
+// CHECK: %[[V2:.*]] = gpu.subgroup_broadcast %[[ARG]], specific_lane %[[IDX]] : f32
+// CHECK: scf.for
+// CHECK: %[[V0:.*]] = gpu.subgroup_broadcast %[[ARG]], first_active_lane : f32
+// CHECK: func.call @side_effect(%[[V0]], %[[V1]], %[[V2]])
+ scf.for %i = %c0 to %arg2 step %c1 {
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : f32
+ %1 = gpu.subgroup_broadcast %arg0, any_lane : f32
+ %2 = gpu.subgroup_broadcast %arg0, specific_lane %arg1 : f32
+ func.call @side_effect(%0, %1, %2) : (f32, f32, f32) -> ()
+ }
+ func.return
+}
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 1613f83..2e92db0 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -329,3 +329,22 @@ module attributes {gpu.container_module} {
}
}
}
+
+// -----
+
+// CHECK-LABEL: func @broadcast
+func.func @broadcast(%idx: i32) {
+ %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+ %1 = gpu.subgroup_broadcast %0, first_active_lane : index
+ %2 = gpu.subgroup_broadcast %0, any_lane : index
+ %3 = gpu.subgroup_broadcast %0, specific_lane %idx : index
+
+ // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
+ // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
+
+ %4 = test.reflect_bounds %1 : index
+ %5 = test.reflect_bounds %2 : index
+ %6 = test.reflect_bounds %3 : index
+ return
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index ee1fdfa..cd889f8 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -17,6 +17,18 @@ module attributes {gpu.container_module} {
return
}
+ // CHECK-LABEL:func @launch_with_module_func_attr(%{{.*}}: index)
+ func.func @launch_with_module_func_attr(%sz : index) {
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) module(@test_module) function(@test_kernel_func)
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+ threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz)
+ module(@test_module) function(@test_kernel_func) {
+ // CHECK: gpu.terminator
+ gpu.terminator
+ }
+ return
+ }
+
// CHECK-LABEL:func @args(%{{.*}}: index, %{{.*}}: index, %{{.*}}: f32, %{{.*}}: memref<?xf32, 1>) {
func.func @args(%blk : index, %thrd : index, %float : f32, %data : memref<?xf32,1>) {
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
@@ -114,7 +126,7 @@ module attributes {gpu.container_module} {
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT: gpu.yield %{{.*}} : f32
// CHECK-NEXT: } : (f32) -> f32
- %sum2 = gpu.all_reduce %one {
+ %sum2 = gpu.all_reduce %one {
^bb(%lhs : f32, %rhs : f32):
%tmp = arith.addf %lhs, %rhs : f32
gpu.yield %tmp : f32
@@ -247,7 +259,7 @@ module attributes {gpu.container_module} {
%1 = arith.cmpi slt, %arg0, %arg0 : i32
scf.if %1 {
gpu.printf ", "
- }
+ }
gpu.return
}
@@ -530,3 +542,15 @@ func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4
}
return %2 : vector<4xi32>
}
+
+// CHECK-LABEL: func @subgroup_broadcast
+// CHECK-SAME: (%[[ARG:.*]]: f32, %[[IDX:.*]]: i32)
+func.func @subgroup_broadcast(%arg0 : f32, %arg1 : i32) -> (f32, f32, f32) {
+ // CHECK: gpu.subgroup_broadcast %[[ARG]], first_active_lane : f32
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : f32
+ // CHECK: gpu.subgroup_broadcast %[[ARG]], any_lane : f32
+ %1 = gpu.subgroup_broadcast %arg0, any_lane : f32
+ // CHECK: gpu.subgroup_broadcast %[[ARG]], specific_lane %[[IDX]] : f32
+ %2 = gpu.subgroup_broadcast %arg0, specific_lane %arg1 : f32
+ func.return %0, %1, %2 : f32, f32, f32
+}
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index d48fa05..e14521a 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -509,7 +509,7 @@ func.func @launch_cluster() {
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>
// -----
-// This test tests the two optional attributes kernelModule and kernelFunc for gpu.launch
+// This test tests the two optional attributes `module` and `function` for gpu.launch
// CHECK-LABEL: func.func @testKernelAttributes()
// CHECK: gpu.launch_func @test_module::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
// CHECK: gpu.module @test_module
@@ -523,15 +523,16 @@ func.func @testKernelAttributes() {
%bDimZ = arith.constant 8 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
- threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+ threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
+ module(@test_module) function(@test_kernel_func) {
"some_op"(%bx, %tx) : (index, index) -> ()
gpu.terminator
- } {kernelModule = @test_module, kernelFunc = @test_kernel_func}
+ }
return
}
// -----
-// This test tests the two optional attributes kernelModule and kernelFunc for gpu.launch, when kernelModule already exists.
+// This test tests the two optional attributes `module` and `function` for gpu.launch, when kernelModule already exists.
// CHECK-LABEL: gpu.module @existing_module
// CHECK: gpu.func @test_kernel_func()
@@ -556,15 +557,16 @@ func.func @testExistingModule() {
%bDimZ = arith.constant 8 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
- threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+ threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
+ module(@existing_module) function(@test_kernel_func) {
"some_op"(%bx, %tx) : (index, index) -> ()
gpu.terminator
- } {kernelModule = @existing_module, kernelFunc = @test_kernel_func}
+ }
return
}
// -----
-// This test tests the optional attribute kernelModule for gpu.launch.
+// This test tests the optional attribute `module` for gpu.launch.
// CHECK-LABEL: func.func @testKernelModuleOnly()
// CHECK: gpu.launch_func @test_module::@testKernelModuleOnly_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
// CHECK: gpu.module @test_module
@@ -578,15 +580,16 @@ func.func @testKernelModuleOnly() {
%bDimZ = arith.constant 8 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
- threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+ threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
+ module(@test_module) {
"some_op"(%bx, %tx) : (index, index) -> ()
gpu.terminator
- } {kernelModule = @test_module}
+ }
return
}
// -----
-// This test tests the optional attribute kernelFunc for gpu.launch.
+// This test tests the optional attribute `function` for gpu.launch.
// CHECK-LABEL: func.func @testKernelFuncOnly()
// CHECK: gpu.launch_func @test_kernel_func::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
@@ -601,15 +604,16 @@ func.func @testKernelFuncOnly() {
%bDimZ = arith.constant 8 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
- threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+ threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
+ function(@test_kernel_func) {
"some_op"(%bx, %tx) : (index, index) -> ()
gpu.terminator
- } {kernelFunc = @test_kernel_func}
+ }
return
}
// -----
-// This test tests gpu.launch when optional attributes kernelModule and kernelFunc are not specified.
+// This test tests gpu.launch when optional attributes `module` and `function` are not specified.
// CHECK-LABEL: func.func @testNoAttributes()
// CHECK: gpu.launch_func @testNoAttributes_kernel::@testNoAttributes_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
@@ -630,3 +634,29 @@ func.func @testNoAttributes() {
}
return
}
+
+// -----
+
+// This test tests nested `gpu.launch`.
+
+// CHECK-LABEL: func.func @nested_launch(
+// CHECK-SAME: %[[ARG0:.*]]: index) {
+// CHECK: gpu.launch_func @nested_launch_kernel_0::@nested_launch_kernel blocks in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) threads in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) args(%[[ARG0]] : index)
+// CHECK: gpu.module @nested_launch_kernel
+// CHECK: gpu.func @nested_launch_kernel() kernel
+// CHECK: "some_op"
+// CHECK: gpu.module @nested_launch_kernel_0
+// CHECK: gpu.func @nested_launch_kernel(%[[VAL_0:.*]]: index) kernel
+// CHECK: gpu.launch_func @nested_launch_kernel::@nested_launch_kernel blocks in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) threads in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]])
+func.func @nested_launch(%sz : index) {
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+ threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+ gpu.launch blocks(%bx1, %by1, %bz1) in (%grid_x1 = %sz, %grid_y1 = %sz, %grid_z1 = %sz)
+ threads(%tx1, %ty1, %tz1) in (%block_x1 = %sz, %block_y1 = %sz, %block_z1 = %sz) {
+ "some_op"(%bx1, %tx1) : (index, index) -> ()
+ gpu.terminator
+ }
+ gpu.terminator
+ }
+ return
+}
diff --git a/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir b/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
index 4293b43..747c997 100644
--- a/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
+++ b/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
@@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
- transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
+ transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu chipset = "gfx950"
} : !transform.any_op
transform.yield
}
@@ -21,3 +21,15 @@ func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
func.return %shfl, %pred : i32, i1
}
+
+ // CHECK-LABEL: func @gpu_shuffle_permlane_swap
+ // CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @gpu_shuffle_permlane_swap(%arg0: i32) -> (i32, i1) {
+ // CHECK: %[[TRUE:.*]] = arith.constant true
+ // CHECK: %[[RES:.*]] = amdgpu.permlane_swap %[[ARG]] 32 : i32
+ // CHECK: return %[[RES]], %[[TRUE]] : i32, i1
+ %width = arith.constant 64 : i32
+ %offset = arith.constant 32 : i32
+ %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+ func.return %shfl, %pred : i32, i1
+}
diff --git a/mlir/test/Dialect/LLVMIR/call-intrin.mlir b/mlir/test/Dialect/LLVMIR/call-intrin.mlir
index b8d845d..bf11e07 100644
--- a/mlir/test/Dialect/LLVMIR/call-intrin.mlir
+++ b/mlir/test/Dialect/LLVMIR/call-intrin.mlir
@@ -27,14 +27,13 @@ llvm.func @round_overloaded() -> f32 {
// CHECK: define void @lifetime_start() {
// CHECK: %1 = alloca float, i8 1, align 4
-// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1)
+// CHECK: call void @llvm.lifetime.start.p0(ptr %1)
// CHECK: ret void
// CHECK: }
llvm.func @lifetime_start() {
- %0 = llvm.mlir.constant(4 : i64) : i64
- %1 = llvm.mlir.constant(1 : i8) : i8
- %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
- llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) {} : (i64, !llvm.ptr) -> ()
+ %0 = llvm.mlir.constant(1 : i8) : i8
+ %1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
+ llvm.call_intrinsic "llvm.lifetime.start"(%1) {} : (!llvm.ptr) -> ()
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index a168ceb..071f124 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -276,12 +276,6 @@ module {
llvm.return
}
- llvm.func @approx_func_fp_math_roundtrip() attributes {approx_func_fp_math = true} {
- // CHECK: @approx_func_fp_math_roundtrip
- // CHECK-SAME: attributes {approx_func_fp_math = true}
- llvm.return
- }
-
llvm.func @no_signed_zeros_fp_math_roundtrip() attributes {no_signed_zeros_fp_math = true} {
// CHECK: @no_signed_zeros_fp_math_roundtrip
// CHECK-SAME: attributes {no_signed_zeros_fp_math = true}
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index 551e0c9..8e292f4 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -299,7 +299,7 @@ llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 {
^bb1:
// Make sure the lifetime begin intrinsic has been inserted where the call
// used to be, even though the alloca has been moved to the entry block.
- // CHECK-NEXT: llvm.intr.lifetime.start 4, %[[PTR]]
+ // CHECK-NEXT: llvm.intr.lifetime.start %[[PTR]]
%0 = llvm.call @static_alloca(%cond1) : (i1) -> f32
// CHECK: llvm.cond_br %{{.+}}, ^[[BB2:.+]], ^[[BB3:.+]]
llvm.br ^bb3(%0: f32)
@@ -307,9 +307,9 @@ llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 {
// return sites of the callee.
// CHECK: ^[[BB2]]:
// CHECK-NEXT: llvm.load
- // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]]
+ // CHECK-NEXT: llvm.intr.lifetime.end %[[PTR]]
// CHECK: ^[[BB3]]:
- // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]]
+ // CHECK-NEXT: llvm.intr.lifetime.end %[[PTR]]
^bb2:
llvm.br ^bb3(%funcArg: f32)
^bb3(%blockArg: f32):
@@ -334,9 +334,9 @@ llvm.func @test_inline(%cond0 : i1) {
// CHECK: "test.one_region_op"() ({
"test.one_region_op"() ({
%0 = llvm.call @static_alloca() : () -> f32
- // CHECK-NEXT: llvm.intr.lifetime.start 4, %[[ALLOCA]]
+ // CHECK-NEXT: llvm.intr.lifetime.start %[[ALLOCA]]
// CHECK-NEXT: %[[RES:.+]] = llvm.load %[[ALLOCA]]
- // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[ALLOCA]]
+ // CHECK-NEXT: llvm.intr.lifetime.end %[[ALLOCA]]
// CHECK-NEXT: test.region_yield %[[RES]]
test.region_yield %0 : f32
}) : () -> ()
@@ -368,9 +368,9 @@ llvm.func @test_inline(%cond0 : i1) {
llvm.func @alloca_with_lifetime(%cond: i1) -> f32 {
%0 = llvm.mlir.constant(4 : i32) : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
- llvm.intr.lifetime.start 4, %1 : !llvm.ptr
+ llvm.intr.lifetime.start %1 : !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> f32
- llvm.intr.lifetime.end 4, %1 : !llvm.ptr
+ llvm.intr.lifetime.end %1 : !llvm.ptr
%3 = llvm.fadd %2, %2 : f32
llvm.return %3 : f32
}
@@ -385,9 +385,9 @@ llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 {
^bb1:
// Make sure the original lifetime intrinsic has been preserved, rather than
// inserting a new one with a larger scope.
- // CHECK: llvm.intr.lifetime.start 4, %[[PTR]]
+ // CHECK: llvm.intr.lifetime.start %[[PTR]]
// CHECK-NEXT: llvm.load %[[PTR]]
- // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]]
+ // CHECK-NEXT: llvm.intr.lifetime.end %[[PTR]]
// CHECK: llvm.fadd
// CHECK-NOT: llvm.intr.lifetime.end
%0 = llvm.call @alloca_with_lifetime(%cond1) : (i1) -> f32
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index ac17374..4394786 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1220,38 +1220,6 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
// -----
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
- // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
- llvm.return
-}
-
-// -----
-
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
- // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
- %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
- llvm.return
-}
-
-// -----
-
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
- // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
- llvm.return
-}
-
-// -----
-
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
- // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
- %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- llvm.return
-}
-
-// -----
-
llvm.func @caller() {
// expected-error @below {{expected function call to produce a value}}
llvm.call @callee() : () -> ()
@@ -1307,8 +1275,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
// -----
func.func @mapa(%a: !llvm.ptr, %b : i32) {
- // expected-error @below {{`res` and `a` should have the same type}}
- %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
+ // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}}
+ %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
return
}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 56634cf..716a586 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -304,10 +304,9 @@ llvm.func @g()
// CHECK-NOT: = llvm.alloca
llvm.func amdgpu_kernelcc @addrspace_discard() {
%0 = llvm.mlir.constant(1 : i32) : i32
- %1 = llvm.mlir.constant(2 : i64) : i64
- %2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
- %3 = llvm.addrspacecast %2 : !llvm.ptr<5> to !llvm.ptr
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
+ %1 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ llvm.intr.lifetime.start %2 : !llvm.ptr
llvm.return
}
@@ -406,9 +405,9 @@ llvm.func @unreachable_jumps_to_merge_point(%arg0: i1) -> i32 {
llvm.func @ignore_lifetime() {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.intr.lifetime.start 2, %1 : !llvm.ptr
+ llvm.intr.lifetime.start %1 : !llvm.ptr
llvm.store %0, %1 {alignment = 4 : i64} : i32, !llvm.ptr
- llvm.intr.lifetime.end 2, %1 : !llvm.ptr
+ llvm.intr.lifetime.end %1 : !llvm.ptr
llvm.return
}
@@ -437,9 +436,9 @@ llvm.func @ignore_discardable_tree() {
%5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(i8, i16)>
%6 = llvm.alloca %0 x !llvm.struct<(i8, i16)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
%7 = llvm.getelementptr %6[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i8, i16)>
- llvm.intr.lifetime.start 2, %7 : !llvm.ptr
+ llvm.intr.lifetime.start %7 : !llvm.ptr
llvm.store %5, %6 {alignment = 2 : i64} : !llvm.struct<(i8, i16)>, !llvm.ptr
- llvm.intr.lifetime.end 2, %7 : !llvm.ptr
+ llvm.intr.lifetime.end %7 : !llvm.ptr
llvm.return
}
@@ -517,8 +516,8 @@ llvm.func @discardable_use_tree() {
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
%4 = llvm.bitcast %3 : !llvm.ptr to !llvm.ptr
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
- llvm.intr.lifetime.start 2, %4 : !llvm.ptr
+ llvm.intr.lifetime.start %3 : !llvm.ptr
+ llvm.intr.lifetime.start %4 : !llvm.ptr
%5 = llvm.intr.invariant.start 2, %3 : !llvm.ptr
llvm.intr.invariant.end %5, 2, %3 : !llvm.ptr
llvm.return
@@ -534,8 +533,8 @@ llvm.func @non_discardable_use_tree() {
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
%4 = llvm.bitcast %3 : !llvm.ptr to !llvm.ptr
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
- llvm.intr.lifetime.start 2, %4 : !llvm.ptr
+ llvm.intr.lifetime.start %3 : !llvm.ptr
+ llvm.intr.lifetime.start %4 : !llvm.ptr
llvm.call @use(%4) : (!llvm.ptr) -> i1
llvm.return
}
@@ -551,8 +550,8 @@ llvm.func @trivial_get_element_ptr() {
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
%4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
- llvm.intr.lifetime.start 2, %4 : !llvm.ptr
+ llvm.intr.lifetime.start %3 : !llvm.ptr
+ llvm.intr.lifetime.start %4 : !llvm.ptr
llvm.return
}
@@ -565,8 +564,8 @@ llvm.func @nontrivial_get_element_ptr() {
// CHECK: = llvm.alloca
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%4 = llvm.getelementptr %2[1] : (!llvm.ptr) -> !llvm.ptr, i8
- llvm.intr.lifetime.start 2, %2 : !llvm.ptr
- llvm.intr.lifetime.start 2, %4 : !llvm.ptr
+ llvm.intr.lifetime.start %2 : !llvm.ptr
+ llvm.intr.lifetime.start %4 : !llvm.ptr
llvm.return
}
@@ -580,8 +579,8 @@ llvm.func @dynamic_get_element_ptr() {
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
%4 = llvm.getelementptr %3[%0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
- llvm.intr.lifetime.start 2, %3 : !llvm.ptr
- llvm.intr.lifetime.start 2, %4 : !llvm.ptr
+ llvm.intr.lifetime.start %3 : !llvm.ptr
+ llvm.intr.lifetime.start %4 : !llvm.ptr
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c..5209b3c 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
llvm.return
}
-// CHECK-LABEL: llvm.func @ld_matrix
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
- %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- llvm.return
-}
-
// CHECK-LABEL: llvm.func @redux_sync
llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
// CHECK: nvvm.redux.sync add %{{.*}}
@@ -535,15 +524,15 @@ func.func @wgmma_wait_group_sync_aligned() {
}
func.func @griddepcontrol_wait() {
- // CHECK: nvvm.griddepcontrol.wait
- nvvm.griddepcontrol.wait
+ // CHECK: nvvm.griddepcontrol wait
+ nvvm.griddepcontrol wait
return
}
func.func @griddepcontrol_launch_dependents()
{
- // CHECK: nvvm.griddepcontrol.launch.dependents
- nvvm.griddepcontrol.launch.dependents
+ // CHECK: nvvm.griddepcontrol launch_dependents
+ nvvm.griddepcontrol launch_dependents
return
}
@@ -552,7 +541,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
// CHECK: nvvm.mapa %{{.*}}
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
// CHECK: nvvm.mapa %{{.*}}
- %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+ %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
return
}
@@ -597,7 +586,7 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c:
}
// CHECK-LABEL: @prefetch
-func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>, %const_ptr: !llvm.ptr<4>) {
// CHECK: nvvm.prefetch level = L1, %{{.*}}
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
// CHECK: nvvm.prefetch level = L1, %{{.*}}
@@ -610,12 +599,24 @@ func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr:
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
// CHECK: nvvm.prefetch level = L2, %{{.*}}
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
- // CHECK: nvvm.prefetch level = L2, %{{.*}}
- nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
- // CHECK: nvvm.prefetch level = L2, %{{.*}}
- nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch level = L2, evict_priority = evict_last, %{{.*}}
+ nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr :
+ !llvm.ptr<1>
+ // CHECK: nvvm.prefetch level = L2, evict_priority = evict_normal, %{{.*}}
+ nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
// CHECK: nvvm.prefetch level = L1 uniform, %{{.*}}
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
+ // CHECK: nvvm.prefetch tensormap, %{{.*}}
+ nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
+ // CHECK: nvvm.prefetch tensormap, %{{.*}}
+ nvvm.prefetch tensormap, %const_ptr : !llvm.ptr<4>
+ // CHECK: nvvm.prefetch tensormap in_param_space, %{{.*}}
+ nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @prefetch_tensormap
+func.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
return
}
diff --git a/mlir/test/Dialect/LLVMIR/ptr.mlir b/mlir/test/Dialect/LLVMIR/ptr.mlir
new file mode 100644
index 0000000..3c208ae
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/ptr.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s --verify-roundtrip
+
+// Check that LLVM ops accept ptr values.
+llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr, %arg1: !llvm.struct<(!ptr.ptr<#llvm.address_space<3>>)>) {
+ %1 = llvm.load %arg0 : !llvm.ptr -> !ptr.ptr<#llvm.address_space<1>>
+ llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr
+ llvm.store %arg1, %arg0 : !llvm.struct<(!ptr.ptr<#llvm.address_space<3>>)>, !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index a2b2f84..782ef4e 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -981,6 +981,13 @@ llvm.func @rocdl.s.wait.expcnt() {
// -----
+llvm.func @rocdl.readfirstlane(%src : f32) -> f32 {
+ // CHECK-LABEL: rocdl.readfirstlane
+ // CHECK: rocdl.readfirstlane %{{.*}} : f32
+ %ret = rocdl.readfirstlane %src : f32
+ llvm.return %ret : f32
+}
+
llvm.func @rocdl.readlane(%src : f32) -> f32 {
%cst0 = llvm.mlir.constant(0 : i32) : i32
@@ -1002,6 +1009,22 @@ llvm.func @rocdl.permlanex16(%src : f32) -> f32 {
// -----
+llvm.func @rocdl.permlane16.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+ // CHECK-LABEL: rocdl.permlane16.swap
+ // CHECK: rocdl.permlane16.swap %{{.*}} %{{.*}}
+ %res = rocdl.permlane16.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+ llvm.return %res : !llvm.struct<(i32, i32)>
+}
+
+llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+ // CHECK-LABEL: rocdl.permlane32.swap
+ // CHECK: rocdl.permlane32.swap %{{.*}} %{{.*}}
+ %res = rocdl.permlane32.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+ llvm.return %res : !llvm.struct<(i32, i32)>
+}
+
+// -----
+
// expected-error@below {{attribute attached to unexpected op}}
func.func private @expected_llvm_func() attributes { rocdl.kernel }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index a0273fb..7344797 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -685,10 +685,10 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
// CHECK-LABEL: @lifetime
// CHECK-SAME: %[[P:.*]]: !llvm.ptr
llvm.func @lifetime(%p: !llvm.ptr) {
- // CHECK: llvm.intr.lifetime.start 16, %[[P]]
- llvm.intr.lifetime.start 16, %p : !llvm.ptr
- // CHECK: llvm.intr.lifetime.end 16, %[[P]]
- llvm.intr.lifetime.end 16, %p : !llvm.ptr
+ // CHECK: llvm.intr.lifetime.start %[[P]]
+ llvm.intr.lifetime.start %p : !llvm.ptr
+ // CHECK: llvm.intr.lifetime.end %[[P]]
+ llvm.intr.lifetime.end %p : !llvm.ptr
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
index fe1531d..1674bbd 100644
--- a/mlir/test/Dialect/LLVMIR/sroa.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -177,7 +177,7 @@ llvm.func @direct_promotable_use_is_fine() -> i32 {
// CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
%3 = llvm.load %2 : !llvm.ptr -> i32
// This is a direct use of the slot but it can be removed because it implements PromotableOpInterface.
- llvm.intr.lifetime.start 2, %1 : !llvm.ptr
+ llvm.intr.lifetime.start %1 : !llvm.ptr
// CHECK: llvm.return %[[RES]] : i32
llvm.return %3 : i32
}
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
index 4ba4b09..2f30e8b 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
@@ -20,20 +20,6 @@ func.func @block_matmul(
return %0 : tensor<64x64xf32>
}
-func.func @block_matmul_transpose_a(
- %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
-func.func @block_matmul_transpose_b(
- %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
// MMT4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// MMT4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// MMT4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
@@ -43,18 +29,6 @@ func.func @block_matmul_transpose_b(
// MMT4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// MMT4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// MMT4D-COUNT-1: linalg.unpack
-// MMT4D-LABEL: func @block_matmul_transpose_a
-// MMT4D-COUNT-3: linalg.pack
-// MMT4D: linalg.generic
-// MMT4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MMT4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MMT4D-COUNT-1: linalg.unpack
-// MMT4D-LABEL: func @block_matmul_transpose_b
-// MMT4D-COUNT-3: linalg.pack
-// MMT4D: linalg.generic
-// MMT4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MMT4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MMT4D-COUNT-1: linalg.unpack
// MM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// MM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -65,18 +39,6 @@ func.func @block_matmul_transpose_b(
// MM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// MM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// MM4D-COUNT-1: linalg.unpack
-// MM4D-LABEL: func @block_matmul_transpose_a
-// MM4D-COUNT-3: linalg.pack
-// MM4D: linalg.generic
-// MM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MM4D-COUNT-1: linalg.unpack
-// MM4D-LABEL: func @block_matmul_transpose_b
-// MM4D-COUNT-3: linalg.pack
-// MM4D: linalg.generic
-// MM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MM4D-COUNT-1: linalg.unpack
// MTM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5, d3)>
// MTM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
@@ -87,15 +49,3 @@ func.func @block_matmul_transpose_b(
// MTM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
// MTM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
// MTM4D-COUNT-1: linalg.unpack
-// MTM4D-LABEL: func @block_matmul_transpose_a
-// MTM4D-COUNT-3: linalg.pack
-// MTM4D: linalg.generic
-// MTM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MTM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MTM4D-COUNT-1: linalg.unpack
-// MTM4D-LABEL: func @block_matmul_transpose_b
-// MTM4D-COUNT-3: linalg.pack
-// MTM4D: linalg.generic
-// MTM4D-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// MTM4D-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// MTM4D-COUNT-1: linalg.unpack
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index aa860db..e16af1f 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -197,150 +197,6 @@ func.func @block_batch_matmul(
// -----
-func.func @block_matmul_transpose_a(
- %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-
-// CHECK-LABEL: func @block_matmul_transpose_a(
-// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
-
-// -----
-
-func.func @block_batch_matmul_transpose_a(
- %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
- %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
- outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
- return %0 : tensor<512x64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
-
-// CHECK-LABEL: func @block_batch_matmul_transpose_a(
-// CHECK-SAME: %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
-
-// -----
-
-func.func @block_matmul_transpose_b(
- %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
- %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
- outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
- return %0 : tensor<64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-
-// CHECK-LABEL: func @block_matmul_transpose_b(
-// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
-
-// -----
-
-func.func @block_batch_matmul_transpose_b(
- %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
- %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
- outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
- return %0 : tensor<512x64x64xf32>
-}
-
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
-
-// CHECK-LABEL: func @block_batch_matmul_transpose_b(
-// CHECK-SAME: %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[A_PACKED:.+]] = linalg.pack %[[A]]
-// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
-// CHECK: %[[B_PACKED:.+]] = linalg.pack %[[B]]
-// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
-// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
-// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[C_PACKED:.+]] = linalg.pack %[[C]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[RES_UNPACKED:.+]] = linalg.unpack %[[GEMM_RES_PACKED]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
-
-// -----
-
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b..8627fcd 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -34,40 +34,35 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-
// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Matmul maps
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK: @conv_16433136
-// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
-// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
-// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+
+// CHECK: @conv_16433136
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<3x3x4x16xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<3x3x4x16xf32> into tensor<36x16xf32>
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
-// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
-// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-// CHECK: linalg.yield %{{.+}} : f32
-// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAPI2C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<1x16x16x4xf32>)
+// CHECK-SAME: outs(%[[INIT_COL_TENSOR]] : tensor<1x196x36xf32>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<1x196x36xf32>
+
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
@@ -180,7 +175,10 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Im2col maps
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAPI2C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -191,9 +189,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3]] : tensor<8x14x14x16xf32> into tensor<8x196x16xf32>
// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x196x36xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAPI2C]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x16x16x4xf32>)
// CHECK-SAME: outs(%[[IT]] : tensor<8x196x36xf32>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<8x196x36xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
@@ -224,13 +226,9 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
// Im2col maps
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 floordiv 9)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<()[s0, s1] -> (s0 floordiv 14 + (s1 mod 9) floordiv 3)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<()[s0, s1] -> (s0 + s1 - (s0 floordiv 14) * 14 - (s1 floordiv 3) * 3)>
-
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 9, d2 floordiv 14 + (d1 mod 9) floordiv 3, d2 mod 14 + d1 mod 3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
@@ -242,32 +240,12 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[CS_RESULT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x14x14xf32> into tensor<8x16x196xf32>
// CHECK: %[[IT:.+]] = tensor.empty() : tensor<8x36x196xf32>
// CHECK: %[[IMG2COL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<8x4x16x16xf32>)
// CHECK-SAME: outs(%[[IT]] : tensor<8x36x196xf32>)
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]]()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]]()[%[[NINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]]()[%[[NINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
-// CHECK: %[[MATMUL:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[LHSMAP]], #[[RHSMAP]], #[[RESMAP]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[CS_FILTER]], %[[IMG2COL]] : tensor<16x36xf32>, tensor<8x36x196xf32>)
-// CHECK-SAME: outs(%[[CS_RESULT]] : tensor<8x16x196xf32>)
-// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32):
-// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
-// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
-// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: ^bb0(%[[IN:.+]]: f32, %out: f32):
+// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<8x16x196xf32>
// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
// CHECK: return %[[CS_FINAL]]
@@ -291,31 +269,19 @@ module attributes {transform.with_named_sequence} {
// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_2d_nhwc_fhwc
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +290,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: #[[MAP4]]
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
index a6552e0..cc7a546 100644
--- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -108,3 +108,69 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
return %0 : tensor<4x?x?x8x2x?xi1>
}
+// -----
+
+// Check a mix of scalar and tensor input.
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @scalar_plus_tensor
+func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[GEN:.*]] = linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+ // CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
+ // CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
+ // CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
+ // CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
+ // CHECK: linalg.yield {{.*}} : f32
+ // CHECK: } -> tensor<?x?xf32>
+ %0 = "test.elementwise_mappable"(%arg0, %arg1)
+ : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+// This test exercises the case where an elementwise op has two scalar-like
+// operands and one ranked tensor operand. In this example, we chain two
+// `test.elementwise_mappable` calls:
+// %0 = f(%s1, %t)
+// %1 = f(%s2, %0)
+// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @scalar_tensor_scalar
+func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
+ // First generic.
+ // CHECK: %[[GEN0:.*]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+ // CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
+ // CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
+ // CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
+ // CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
+ // CHECK: linalg.yield %[[APPLY0]] : f32
+ // CHECK: } -> tensor<?x?xf32>
+
+ // Second generic.
+ // CHECK: %[[GEN1:.*]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+ // CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
+ // CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
+ // CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
+ // CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
+ // CHECK: linalg.yield %[[APPLY1]] : f32
+ // CHECK: } -> tensor<?x?xf32>
+ // CHECK: return %[[GEN1]] : tensor<?x?xf32>
+ %0 = "test.elementwise_mappable"(%s1, %t)
+ : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = "test.elementwise_mappable"(%s2, %0)
+ : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// ----
+// CHECK-LABEL: func @negative_scalar_only_eltwise
+// CHECK-NOT: linalg
+func.func @negative_scalar_only_eltwise(%a: f32, %b: f32) -> f32 {
+ %0 = arith.addf %a, %b : f32
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa4..0e42027 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,116 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG1]]
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+ func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+ }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON:.+]] = ub.poison : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %arg1
+// CHECK: tensor.yield %[[POISON]] : f32
+// CHECK: } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK: return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<?x5x128xbf16>
+ return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+ %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+ ^bb0(%in: f32, %in_0: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ linalg.yield %1 : bf16
+ } -> tensor<128x?x128xbf16>
+ return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK: %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: tensor.yield %[[POISON_F32]] : f32
+// CHECK: } : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK: tensor.yield %[[POISON_BF16]] : bf16
+// CHECK: } : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PADDED]]
+// CHECK-SAME: outs(%[[PADDED1]]
+// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %1 = arith.truncf %in : f32 to bf16
+ %2 = arith.addf %1, %out : bf16
+ linalg.yield %2 : bf16
+ } -> tensor<?xbf16>
+ return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
index e173d55..a53dde8 100644
--- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
@@ -203,3 +203,20 @@ func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
// CHECK: return %[[INSERT]]
+
+// -----
+
+/// Note "126", which is a non-unit tile-outer-dim. This is not supported.
+
+func.func @negative_non_unit_tiled_outer_dim(%src: tensor<1x126x1x1x8xf32>, %dest: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 3, 2, 1]
+ inner_dims_pos = [3]
+ inner_tiles = [8]
+ into %dest : tensor<1x126x1x1x8xf32>
+ -> tensor<1x1x1x1001xf32>
+
+ return %unpack : tensor<1x1x1x1001xf32>
+}
+// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
+// CHECK: linalg.unpack
diff --git a/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
new file mode 100644
index 0000000..2332b28
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s
+
+// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %exp : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %add : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @sub(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>)
+//
+func.func @sub(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %sub = linalg.sub ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %sub : tensor<16x8xf32>
+}
+
+// ----
+
+// CHECK: @ternary_select(%[[A:.+]]: tensor<4x8x16xi1>, %[[B:.+]]: tensor<4x8x16xf32>, %[[C:.+]]: tensor<4x8x16xf32>)
+// CHECK: %[[E:.+]] = tensor.empty() : tensor<4x8x16xf32>
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<select>
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+// CHECK-SAME: outs(%[[E]] : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+//
+func.func @ternary_select(%A: tensor<4x8x16xi1>, %B: tensor<4x8x16xf32>, %C: tensor<4x8x16xf32>)
+ -> tensor<4x8x16xf32> {
+ %empty = tensor.empty() : tensor<4x8x16xf32>
+ %select = linalg.select
+ ins(%A, %B, %C : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+ outs(%empty: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ return %select : tensor<4x8x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
index d8e92e4..e90247d 100644
--- a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
+++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir
@@ -158,36 +158,6 @@ module attributes {transform.with_named_sequence} {
// -----
!type = tensor<2048x2048xf32>
-func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
- %0 = arith.constant dense<1.111111e+00> : !type
- %cst = arith.constant 0.000000e+00 : f32
- %1 = tensor.empty() : !type
- %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
- %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
- %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
- %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
- return %5 : !type
-}
-
-// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
-// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
-// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
-// CHECK-NOT: linalg.add
-// CHECK-NEXT: return %[[RES]]
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.linalg.fold_add_into_dest
- } : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-!type = tensor<2048x2048xf32>
func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
%0 = arith.constant dense<1.111111e+00> : !type
%cst = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir b/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
new file mode 100644
index 0000000..2469023
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
@@ -0,0 +1,15 @@
+// Forward path `named -> category -> generic`
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | FileCheck %s --check-prefix=NAMED_TO_CATEGORY
+
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | \
+// RUN: mlir-opt -linalg-morph-ops=category-to-generic | FileCheck %s --check-prefix=CATEGORY_TO_GENERIC
+
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %exp : tensor<16x8xf32>
+}
+// NAMED_TO_CATEGORY: linalg.elementwise
+// NAMED_TO_CATEGORY-NOT: linalg.exp
+
+// CATEGORY_TO_GENERIC: linalg.generic
+// CATEGORY_TO_GENERIC-NOT: linalg.elementwise
diff --git a/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
new file mode 100644
index 0000000..bdd29b9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
+// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt -linalg-morph-ops=generic-to-named | \
+// RUN: FileCheck %s --check-prefix=ROUND_TRIP
+
+func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %exp : tensor<16x8xf32>
+}
+
+// NAMED_TO_GENERIC: linalg.generic
+// NAMED_TO_GENERIC-NOT: linalg.exp
+
+// ROUND_TRIP: linalg.exp
+// ROUND_TRIP-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 412f40d..a93e979 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1222,17 +1222,6 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
// -----
-// CHECK-LABEL: func @matmul_transpose_a
-// CHECK: linalg.matmul_transpose_a
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul_transpose_a ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// -----
-
// CHECK-LABEL: func @matmul_transpose_a_explicit
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
@@ -1478,17 +1467,6 @@ func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf3
// -----
-// CHECK-LABEL: func @matmul_transpose_b
-// CHECK: linalg.matmul_transpose_b
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
- return
-}
-
-// -----
-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -1806,28 +1784,6 @@ func.func @bcast_A_transpose_B(%A: memref<3x5xf32>, %B: memref<2x7x5xf32>, %C: m
// -----
-// CHECK-LABEL: func @batchmatmul_transpose_a
-// CHECK: linalg.batch_matmul_transpose_a
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batchmatmul_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
- linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @batchmatmul_transpose_b
-// CHECK: linalg.batch_matmul_transpose_b
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
- linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
- return
-}
-
-// -----
-
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index 43bddb0..704576d 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -92,38 +92,6 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
// -----
-func.func @singleton_batchmatmul_transpose_a(%arg0: memref<1x5x3xf32>, %arg1: memref<1x5x7xf32>, %arg2: memref<1x3x7xf32>) {
- // CHECK-LABEL: @singleton_batchmatmul_transpose_a
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x5x3xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x5x7xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul_transpose_a ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : memref<1x5x3xf32>, memref<1x5x7xf32>) outs(%arg2: memref<1x3x7xf32>)
- return
-}
-
-// -----
-
-func.func @singleton_batchmatmul_transpose_b(%arg0: memref<1x3x5xf32>, %arg1: memref<1x7x5xf32>, %arg2: memref<1x3x7xf32>) {
- // CHECK-LABEL: @singleton_batchmatmul_transpose_b
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<1x3x5xf32>
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<1x7x5xf32>
- // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: memref<1x3x7xf32>
- // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = memref.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = memref.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = memref.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]]
- // CHECK-NEXT: linalg.matmul_transpose_b ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[COLLAPSED_INIT]] : memref<3x7xf32>)
- // CHECK-NEXT: return
- linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<1x3x5xf32>, memref<1x7x5xf32>) outs(%arg2: memref<1x3x7xf32>)
- return
-}
-
-// -----
-
func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK-LABEL: @matmul_to_matvec_tensor
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -226,59 +194,6 @@ func.func @matvec_to_dot_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?xf32>, %a
// -----
-func.func @matmul_transpose_a_to_vecmat(%arg0: tensor<256x1xf32>, %arg1: tensor<256x512xf32>, %arg2: tensor<1x512xf32>) -> tensor<1x512xf32> {
- // CHECK-LABEL: @matmul_transpose_a_to_vecmat
- // CHECK: collapse_shape {{.*}} into tensor<256xf32>
- // CHECK: collapse_shape {{.*}} into tensor<512xf32>
- // CHECK: linalg.vecmat
- // CHECK: expand_shape {{.*}} into tensor<1x512xf32>
- %0 = linalg.matmul_transpose_a ins(%arg0, %arg1: tensor<256x1xf32>, tensor<256x512xf32>) outs(%arg2: tensor<1x512xf32>) -> tensor<1x512xf32>
- return %0 : tensor<1x512xf32>
-}
-
-// -----
-
-func.func @batch_matmul_transpose_a_to_batch_vecmat(%arg0: tensor<64x256x1xf32>, %arg1: tensor<64x256x512xf32>, %arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32> {
- // CHECK-LABEL: @batch_matmul_transpose_a_to_batch_vecmat
- // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
- // CHECK: collapse_shape {{.*}} into tensor<64x512xf32>
- // CHECK: linalg.batch_vecmat
- // CHECK: expand_shape {{.*}} into tensor<64x1x512xf32>
- %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1: tensor<64x256x1xf32>, tensor<64x256x512xf32>) outs(%arg2: tensor<64x1x512xf32>) -> tensor<64x1x512xf32>
- return %0 : tensor<64x1x512xf32>
-}
-
-// -----
-
-func.func @matmul_transpose_b_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<1x?xf32>, %arg2: memref<?x1xf32>) {
- // CHECK-LABEL: @matmul_transpose_b_to_matvec
- // CHECK: linalg.matvec
- linalg.matmul_transpose_b ins(%arg0, %arg1: memref<?x?xf32>, memref<1x?xf32>) outs(%arg2: memref<?x1xf32>)
- return
-}
-
-// -----
-
-func.func @batchmatmul_transpose_b_to_batchmatvec_tensor(%arg0: tensor<64x128x256xf32>, %arg1: tensor<64x1x256xf32>, %arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32> {
- // CHECK: collapse_shape {{.*}} into tensor<64x256xf32>
- // CHECK: collapse_shape {{.*}} into tensor<64x128xf32>
- // CHECK: linalg.batch_matvec
- // CHECK: expand_shape {{.*}} into tensor<64x128x1xf32>
- %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<64x128x256xf32>, tensor<64x1x256xf32>) outs(%arg2: tensor<64x128x1xf32>) -> tensor<64x128x1xf32>
- return %0 : tensor<64x128x1xf32>
-}
-
-// -----
-
-func.func @batchmatmul_transpose_b_to_to_dot(%arg0: tensor<1x1x?xf32>, %arg1: tensor<1x1x?xf32>, %arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32> {
- // CHECK-LABEL: @batchmatmul_transpose_b_to_to_dot
- // CHECK: linalg.dot
- %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1: tensor<1x1x?xf32>, tensor<1x1x?xf32>) outs(%arg2: tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
- return %0 : tensor<1x1x1xf32>
-}
-
-// -----
-
func.func @nonsingleton_batch_matmul(%arg0 : tensor<2x?x?xf32>, %arg1 : tensor<2x?x?xf32>, %arg2: tensor<2x?x?xf32>) -> tensor<2x?x?xf32> {
// CHECK-LABEL: @nonsingleton_batch_matmul
// CHECK-NOT: collapse_shape
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4edbc6e..563013d 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
+
+// -----
+
+
+func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32) {
+ %0 = arith.addf %in1, %in2: f32
+ linalg.yield %in1: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// CHECK-LABEL: func @reduce_not_short_form_compatible
+// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME: %[[INIT:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32>
+// CHECK: linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>)
+// CHECK-SAME: dimensions = [1]
+// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
// -----
func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,27 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// -----
+func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
+ %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
+ (%in_1: f32, %in_2: f32) {
+ %1 = arith.maximumf %in_1, %in_2 : f32
+ linalg.yield %in_1 : f32
+ }
+ func.return %mapped : tensor<1x32xf32>
+}
+
+// CHECK-LABEL: func @map_not_short_form_compatible
+// CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
+// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
+// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
+// -----
+
func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
diff --git a/mlir/test/Dialect/Linalg/namedop_conversion.mlir b/mlir/test/Dialect/Linalg/simplify-depthwise-conv.mlir
index 4f2f272..70e68e7 100644
--- a/mlir/test/Dialect/Linalg/namedop_conversion.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-depthwise-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --simplify-depthwise-conv -split-input-file | FileCheck %s
// CHECK-LABEL: @depthwise_conv
func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 778d5bb..1b0bade 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -504,7 +504,7 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
%c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
%sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f741876..9a3dcf0 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -14,11 +14,11 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
// Tile to 5 then pad to 8
- %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
+ %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
- padding_values=[0.0 : f32, 0.0 : f32]
+ padding_values= [#ub.poison, 0.0 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
@@ -33,9 +33,9 @@ func.func @pad_lhs(
-> tensor<24x25xf32>
{
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
- // CHECK: tensor.pad %{{.*}}
+ // CHECK: tensor.pad %{{.*}}
// CHECK: : tensor<?x12xf32> to tensor<8x12xf32>
- // CHECK: tensor.pad %{{.*}}
+ // CHECK: tensor.pad %{{.*}}
// CHECK: : tensor<?x25xf32> to tensor<8x25xf32>
// CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
// CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -92,7 +92,7 @@ module {
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
}
@@ -147,7 +147,7 @@ module {
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index f91eb9c..51bf4a2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -465,14 +465,14 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[RHS:.*]] = tensor.pad
// CHECK: scf.for
// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
-// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
+// CHECK-DAG: tensor.extract_slice %[[RHS]][%{{.*}}, 0] [32, %{{.*}}]
func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
index f64953b..bd4c655 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir
@@ -30,66 +30,6 @@ module attributes {transform.with_named_sequence} {
// -----
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.generic
- {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2 : memref<3x7xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %0 = arith.mulf %in, %in_0 : f32
- %1 = arith.addf %out, %0 : f32
- linalg.yield %1 : f32
- }
- return
-}
-
-// CHECK-LABEL: @matmul_transpose_a
-// CHECK-SAME: %[[ARG0:.+]]: memref<5x3xf32>, %[[ARG1:.+]]: memref<5x7xf32>, %[[ARG2:.+]]: memref<3x7xf32>) {
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_a ins(%[[ARG0]], %[[ARG1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[ARG2]] : memref<3x7xf32>)
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @matmul_transpose_b(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.generic
- {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.mulf %in, %in_0 : f32
- %2 = arith.addf %out, %1 : f32
- linalg.yield %2 : f32
- } -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: @matmul_transpose_b
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -117,32 +57,3 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
-// -----
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-func.func @batch_matmul_transpose_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
- %0 = linalg.generic
- {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 : tensor<?x?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %1 = arith.mulf %in, %in_0 : f32
- %2 = arith.addf %out, %1 : f32
- linalg.yield %2 : f32
- } -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
-}
-
-// CHECK-LABEL: @batch_matmul_transpose_b
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.batch_matmul_transpose_b ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[ARG2]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
index d2b7e9f..4ee87fb 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,6 +1,20 @@
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-b.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
+// TRANSPOSE-A-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// TRANSPOSE-A-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// TRANSPOSE-A-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// TRANSPOSE-A-DAG: #[[$BMA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// TRANSPOSE-A-DAG: #[[$BMB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// TRANSPOSE-A-DAG: #[[$BMC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// TRANSPOSE-B-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// TRANSPOSE-B-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// TRANSPOSE-B-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// TRANSPOSE-B-DAG: #[[$BMA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// TRANSPOSE-B-DAG: #[[$BMB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// TRANSPOSE-B-DAG: #[[$BMC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
// CHECK-LABEL: func.func @matmul_static(
// CHECK-SAME: %[[A:.*]]: tensor<16x8xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
@@ -9,10 +23,10 @@
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
-// TRANSPOSE-A: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
-// TRANSPOSE-B: %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// CHECK: return %[[C]] : tensor<16x16xf32>
// CHECK: }
func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
@@ -38,11 +52,11 @@ func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<
// TRANSPOSE-A: %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// TRANSPOSE-A: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// TRANSPOSE-B: %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// TRANSPOSE-B: %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return %[[C]] : tensor<?x?xf32>
// CHECK: }
func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
@@ -69,10 +83,10 @@ func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
-// TRANSPOSE-A: %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-A: %[[B0:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
-// TRANSPOSE-B: %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-B: %[[B0:.*]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// CHECK: return %[[B0]] : tensor<?x16xf32>
// CHECK: }
func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
@@ -96,10 +110,10 @@ func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x8x16xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x16xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
// CHECK: return %[[C]] : tensor<2x16x16xf32>
// CHECK: }
func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x16x16xf32>) {
@@ -127,12 +141,12 @@ func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -
// TRANSPOSE-A: %[[A_DIM2:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM2]], %[[A_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-A: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM0]], %[[B_DIM2]], %[[B_DIM1]]) : tensor<?x?x?xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-B: %[[C:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: return %[[C]] : tensor<?x?x?xf32>
// CHECK: }
func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
@@ -161,10 +175,10 @@ func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) ->
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x8x?xf32>
// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x?xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-A: %[[B0:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-A: %[[B0:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
-// TRANSPOSE-B: %[[B0:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-B: %[[B0:.*]] = linalg.batch_matmul indexing_maps = [#[[$BMA]], #[[$BMB]], #[[$BMC]]] ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
// CHECK: return %[[B0]] : tensor<2x?x16xf32>
// CHECK: }
func.func @batch_matmul_mixed(%A: tensor<2x?x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x?x16xf32>) {
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 4eeae4c..25cbceb 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -61,6 +61,83 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @float_mixed_precision_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract {{.*}} : vector<1584x1584xbf16>, vector<1584x1584xbf16> into vector<1584x1584xf32>
+func.func @float_mixed_precision_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+ linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @vectorization_test_2
+func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+ %C: memref<8x32xf32>) {
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
+ linalg.matmul
+ ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C: memref<8x32xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
+func.func @matmul_tensors(
+ %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
+ -> tensor<8x12xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
+ // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
+ // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+ //
+ // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
+ // convert it to a 2D contract.
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+ // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
+ outs(%arg2: tensor<8x12xf32>)
+ -> tensor<8x12xf32>
+ // CHECK: return %[[W]] : tensor<8x12xf32>
+ return %0 : tensor<8x12xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: contraction_batch_matmul
func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
@@ -115,6 +192,265 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @float_mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract {{.*}} : vector<24x12xbf16>, vector<12x25xbf16> into vector<24x25xf32>
+// CHECK: vector.transfer_write
+func.func @float_mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+ %B: tensor<12x25xbf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_fill
+func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+ // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+ linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_fill
+func.func @test_vectorize_fill_0d(%A : memref<f32>, %arg0 : f32) {
+ // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
+ // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
+ linalg.fill ins(%arg0 : f32) outs(%A : memref<f32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy
+func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+ memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_0d
+func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
+ // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
+ // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
+ // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
+ // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
+ memref.copy %A, %B : memref<f32> to memref<f32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_complex
+// CHECK-NOT: vector<
+func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
+ memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Input identical as the test in vectorization.mlir. Output is different -
+// vector sizes are inferred (rather than user-specified) and hence _no_
+// masking was used.
+
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
+ return %pack : tensor<4x1x32x16x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @test_vectorize_pack(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
+// CHECK: return %[[VAL_8]] : tensor<4x1x32x16x2xf32>
+
+// -----
+
+func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+ %pad = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+ return %pack : tensor<32x4x1x16x2xf32>
+}
+
+// CHECK-LABEL: func.func @test_vectorize_padded_pack(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[VAL_8]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_map(%arg0: memref<64xf32>,
+ %arg1: memref<64xf32>, %arg2: memref<64xf32>) {
+ linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
+ outs(%arg2 : memref<64xf32>)
+ (%in: f32, %in_0: f32) {
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+// CHECK-LABEL: func @vectorize_map
+// CHECK: %[[LHS:.*]] = vector.transfer_read
+// CHECK-NEXT: %[[RHS:.*]] = vector.transfer_read
+// CHECK-NEXT: arith.addf %[[LHS]], %[[RHS]] : vector<64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.map"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>,
+ %arg1: memref<32x64x16xf32>) {
+ linalg.transpose ins(%arg0 : memref<16x32x64xf32>)
+ outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
+ return
+}
+// CHECK-LABEL: func @vectorize_transpose
+// CHECK: vector.transpose
+// CHECK-SAME: [1, 2, 0] : vector<16x32x64xf32> to vector<32x64x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.transpose"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>,
+ %arg1: memref<16x64xf32>) {
+ linalg.reduce ins(%arg0 : memref<16x32x64xf32>)
+ outs(%arg1 : memref<16x64xf32>) dimensions = [1]
+ (%in: f32, %init: f32) {
+ %0 = arith.addf %in, %init : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+// CHECK-LABEL: func @vectorize_reduce
+// CHECK: vector.multi_reduction <add>
+// CHECK-SAME: : vector<16x32x64xf32> to vector<16x64xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
#matmul_trait = {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>,
@@ -306,27 +642,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @vectorization_test_2
-func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
- %C: memref<8x32xf32>) {
- // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
- linalg.matmul
- ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
- outs(%C: memref<8x32xf32>)
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
// CHECK-LABEL: func @test_vectorize_scalar_input
func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
@@ -427,104 +742,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_fill
-func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
- // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
- // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
- linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>)
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_fill
-func.func @test_vectorize_fill_0d(%A : memref<f32>, %arg0 : f32) {
- // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
- // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
- // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
- linalg.fill ins(%arg0 : f32) outs(%A : memref<f32>)
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy
-func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
- // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
- // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
- memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy_0d
-func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
- // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
- // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
- // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
- // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
- // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
- memref.copy %A, %B : memref<f32> to memref<f32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy_complex
-// CHECK-NOT: vector<
-func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
- memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: func @test_vectorize_trailing_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
func.func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
@@ -855,40 +1072,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @matmul_tensors
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
-// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
-func.func @matmul_tensors(
- %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
- -> tensor<8x12xf32> {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
- // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
- // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
- //
- // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
- // convert it to a 2D contract.
- // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
- // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
- outs(%arg2: tensor<8x12xf32>)
- -> tensor<8x12xf32>
- // CHECK: return %[[W]] : tensor<8x12xf32>
- return %0 : tensor<8x12xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: func @sum_exp
func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
-> tensor<4x16xf32>
@@ -914,7 +1097,6 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
return %0 : tensor<4x16xf32>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -993,7 +1175,6 @@ func.func @red_maximumf_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -1428,78 +1609,6 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @vectorize_map(%arg0: memref<64xf32>,
- %arg1: memref<64xf32>, %arg2: memref<64xf32>) {
- linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
- outs(%arg2 : memref<64xf32>)
- (%in: f32, %in_0: f32) {
- %0 = arith.addf %in, %in_0 : f32
- linalg.yield %0 : f32
- }
- return
-}
-// CHECK-LABEL: func @vectorize_map
-// CHECK: %[[LHS:.*]] = vector.transfer_read
-// CHECK-NEXT: %[[RHS:.*]] = vector.transfer_read
-// CHECK-NEXT: arith.addf %[[LHS]], %[[RHS]] : vector<64xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.map"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>,
- %arg1: memref<32x64x16xf32>) {
- linalg.transpose ins(%arg0 : memref<16x32x64xf32>)
- outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
- return
-}
-// CHECK-LABEL: func @vectorize_transpose
-// CHECK: vector.transpose
-// CHECK-SAME: [1, 2, 0] : vector<16x32x64xf32> to vector<32x64x16xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.transpose"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>,
- %arg1: memref<16x64xf32>) {
- linalg.reduce ins(%arg0 : memref<16x32x64xf32>)
- outs(%arg1 : memref<16x64xf32>) dimensions = [1]
- (%in: f32, %init: f32) {
- %0 = arith.addf %in, %init : f32
- linalg.yield %0 : f32
- }
- return
-}
-// CHECK-LABEL: func @vectorize_reduce
-// CHECK: vector.multi_reduction <add>
-// CHECK-SAME: : vector<16x32x64xf32> to vector<16x64xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// This is a regression test. This IR cannot be vectorized, but
// structured.vectorize_children_and_apply_patterns should nevertheless succeed.
@@ -1715,65 +1824,77 @@ module attributes {transform.with_named_sequence} {
// -----
-// Input identical as the test in vectorization.mlir. Output is different -
-// vector sizes are inferred (rather than user-specified) and hence _no_
-// masking was used.
-
-func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
- %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
- return %pack : tensor<4x1x32x16x2xf32>
+// CHECK-LABEL: func @float_mixed_precision_matmul_as_generic
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract {{.*}} : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>
+// CHECK: vector.transfer_write
+func.func @float_mixed_precision_matmul_as_generic(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+ %C: memref<8x32xf32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+ %a = arith.extf %in : bf16 to f32
+ %b = arith.extf %in_0 : bf16 to f32
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
}
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
-// CHECK-LABEL: func.func @test_vectorize_pack(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
-// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
-// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
-// CHECK: return %[[VAL_8]] : tensor<4x1x32x16x2xf32>
-
// -----
-// Input identical as the test in vectorization.mlir. Output is different -
-// vector sizes are inferred (rather than user-specified) and hence _no_
-// masking was used.
-
-func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
- %pad = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
- return %pack : tensor<32x4x1x16x2xf32>
+// CHECK-LABEL: func @integer_mixed_precision_matmul_as_generic
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extsi
+// CHECK: vector.contract {{.*}} : vector<8x16xi8>, vector<16x32xi8> into vector<8x32xi32>
+// CHECK: vector.transfer_write
+func.func @integer_mixed_precision_matmul_as_generic(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
+ %C: memref<8x32xi32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
+ outs(%C : memref<8x32xi32>) {
+ ^bb(%in: i8, %in_0: i8, %c: i32) :
+ %a = arith.extsi %in : i8 to i32
+ %b = arith.extsi %in_0 : i8 to i32
+ %d = arith.muli %a, %b: i32
+ %e = arith.addi %c, %d: i32
+ linalg.yield %e : i32
+ }
+ return
}
-// CHECK-LABEL: func.func @test_vectorize_padded_pack(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
-// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
-// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-// CHECK: return %[[VAL_8]] : tensor<32x4x1x16x2xf32>
-
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
+
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 095810f..01eb210 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -880,22 +880,22 @@ func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
-// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
-// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
-// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[C16_M:.*]] = arith.constant 16 : index
+// CHECK: %[[C16_N:.*]] = arith.constant 16 : index
+// CHECK: %[[C16_K:.*]] = arith.constant 16 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C16_N]], %[[C16_K]], %[[DIM_2]], %[[C1]] : vector<16x16x[4]x1xi1>
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
-// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
-// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
-// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
-// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_2]], %[[C1]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
module attributes {transform.with_named_sequence} {
@@ -920,10 +920,10 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
-// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
-// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -936,6 +936,100 @@ module attributes {transform.with_named_sequence} {
// -----
///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
+ linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
+ outs(%C_in: memref<2x16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %batch_mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+ linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+ outs(%C_in: memref<2x16x16x8x?xf32>)
+ return
+}
+// CHECK-LABEL: func.func @batch_mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C16_M:.*]] = arith.constant 16 : index
+// CHECK: %[[C16_N:.*]] = arith.constant 16 : index
+// CHECK: %[[C16_K:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_N_IN:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[C2]], %[[C16_N]], %[[C16_K]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_N_IN]] : vector<2x16x16x8x[4]xi1>
+// CHECK: %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x16x8x[4]x1xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+ linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+ outs(%C_in: memref<2x16x16x8x?xf32>)
+ return
+}
+// CHECK-LABEL: func.func @batch_mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+ transform.yield
+ }
+}
+
+
+// -----
+
+///----------------------------------------------------------------------------------------
/// Tests for linalg.unpack
///----------------------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1420aca..615c607 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,7 +1,9 @@
-// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER
// CHECK-LABEL: func @tanh
func.func @tanh(%arg: f32) -> f32 {
+ // CHECK-FILTER-NOT: math.tanh
%res = math.tanh %arg : f32
return %res : f32
}
@@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK-LABEL: func @vector_tanh
func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
// CHECK-NOT: math.tanh
+ // CHECK-FILTER-NOT: math.tanh
%res = math.tanh %arg : vector<4xf32>
return %res : vector<4xf32>
}
@@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
// CHECK-LABEL: func @tan
func.func @tan(%arg: f32) -> f32 {
+ // CHECK-FILTER-NOT: math.tan
%res = math.tan %arg : f32
return %res : f32
}
@@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 {
// CHECK-LABEL: func @vector_tan
func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
+ // CHECK-FILTER-NOT: math.tan
%res = math.tan %arg : vector<4xf32>
return %res : vector<4xf32>
}
@@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
// -----
func.func @ctlz(%arg: i32) -> i32 {
+ // CHECK-FILTER: math.ctlz
%res = math.ctlz %arg : i32
return %res : i32
}
@@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 {
// -----
func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+ // CHECK-FILTER: math.ctlz
%res = math.ctlz %arg : vector<4xi32>
return %res : vector<4xi32>
}
@@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
+ // CHECK-FILTER: math.ceil
%ret = math.ceil %a : f64
return %ret : f64
}
@@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 {
// CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
// CHECK: [[EXP:%.+]] = math.exp [[MULF]]
// CHECK: return [[EXP]]
+ // CHECK-FILTER: math.exp2
%ret = math.exp2 %a : f64
return %ret : f64
}
@@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
%a = math.rsqrt %arg : tensor<*xf32>
return %a: tensor<*xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @clampf_scalar_op
+// CHECK-SAME: (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
+// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16
+// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16
+// CHECK: return %[[V1]] : f16
+
+func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
+ %a = math.clampf %arg to [%min, %max] : f16
+ return %a: f16
+}
+
+// CHECK-LABEL: func.func @clampf_vector_op
+// CHECK-SAME: (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
+// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
+// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK: return %[[V1]] : vector<3x4xf32>
+
+func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{
+ %a = math.clampf %arg to [%min, %max] fastmath<fast> : vector<3x4xf32>
+ return %a: vector<3x4xf32>
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 8feaded..cb10fc4 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
// CHECK-LABEL: func @atan(
@@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>)
math.isnormal %t : tensor<4x?xf32>
return
}
+
+// CHECK-LABEL: func @clampf(
+func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>,
+ %as: f32, %ms: f32, %Ms: f32,
+ %at: tensor<?xf80>, %mt: tensor<?xf80>, %Mt: tensor<?xf80>) {
+ // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath<fast> : vector<3x4xf32>
+ %rv = math.clampf %av to [%mv, %Mv] fastmath<fast> : vector<3x4xf32>
+ // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32
+ %rs = math.clampf %as to [%ms, %Ms] fastmath<none> : f32
+ // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor<?xf80>
+ %rt = math.clampf %at to [%mt, %Mt] : tensor<?xf80>
+ return
+}
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 2b64fa4..f735e3f 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -378,3 +378,14 @@ func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %ar
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x8xi8,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_last_dim_bytes(%desc: !desc, %buffer: memref<32x8xi8,3>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{the bytes in the last dimension of the tensor map must be a multiple of 16}}
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<32x8xi8,3>
+ return
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 7bb6cf4..5a3bbaf 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1954,6 +1954,70 @@ acc.reduction.recipe @reduction_add_memref_i32 : memref<i32> reduction_operator
// CHECK-LABEL: acc.reduction.recipe @reduction_add_memref_i32
// CHECK: memref.alloca
+// -----
+
+// Test reduction recipe with destroy region using dynamic memory allocation
+acc.reduction.recipe @reduction_add_with_destroy : memref<?xf32> reduction_operator<add> init {
+^bb0(%arg0: memref<?xf32>):
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %size = memref.dim %arg0, %c0 : memref<?xf32>
+ %alloc = memref.alloc(%size) : memref<?xf32>
+ %c1 = arith.constant 1 : index
+ scf.for %i = %c0 to %size step %c1 {
+ memref.store %cst, %alloc[%i] : memref<?xf32>
+ }
+ acc.yield %alloc : memref<?xf32>
+} combiner {
+^bb0(%arg0: memref<?xf32>, %arg1: memref<?xf32>):
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %size = memref.dim %arg0, %c0 : memref<?xf32>
+ scf.for %i = %c0 to %size step %c1 {
+ %val0 = memref.load %arg0[%i] : memref<?xf32>
+ %val1 = memref.load %arg1[%i] : memref<?xf32>
+ %sum = arith.addf %val0, %val1 : f32
+ memref.store %sum, %arg0[%i] : memref<?xf32>
+ }
+ acc.yield %arg0 : memref<?xf32>
+} destroy {
+^bb0(%arg0: memref<?xf32>):
+ // destroy region to deallocate dynamically allocated memory
+ memref.dealloc %arg0 : memref<?xf32>
+ acc.yield
+}
+
+// CHECK-LABEL: acc.reduction.recipe @reduction_add_with_destroy : memref<?xf32> reduction_operator <add> init {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<?xf32>):
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[SIZE:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[SIZE]]) : memref<?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[SIZE]] step %[[C1]] {
+// CHECK: memref.store %[[CST]], %[[ALLOC]][%[[I]]] : memref<?xf32>
+// CHECK: }
+// CHECK: acc.yield %[[ALLOC]] : memref<?xf32>
+// CHECK: } combiner {
+// CHECK: ^bb0(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>):
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[SIZE_1:.*]] = memref.dim %[[ARG0]], %[[C0_1]] : memref<?xf32>
+// CHECK: scf.for %[[I_1:.*]] = %[[C0_1]] to %[[SIZE_1]] step %[[C1_1]] {
+// CHECK: %{{.*}} = memref.load %[[ARG0]][%[[I_1]]] : memref<?xf32>
+// CHECK: %{{.*}} = memref.load %[[ARG1]][%[[I_1]]] : memref<?xf32>
+// CHECK: %[[SUM:.*]] = arith.addf %{{.*}}, %{{.*}} : f32
+// CHECK: memref.store %[[SUM]], %[[ARG0]][%[[I_1]]] : memref<?xf32>
+// CHECK: }
+// CHECK: acc.yield %[[ARG0]] : memref<?xf32>
+// CHECK: } destroy {
+// CHECK: ^bb0(%[[ARG_DESTROY:.*]]: memref<?xf32>):
+// CHECK: memref.dealloc %[[ARG_DESTROY]] : memref<?xf32>
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
acc.private.recipe @privatization_memref_i32 : memref<i32> init {
^bb0(%arg0: memref<i32>):
%alloca = memref.alloca() : memref<i32>
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 5088f2d..986c384 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3017,3 +3017,110 @@ func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
return
}
+
+// -----
+func.func @invalid_workdistribute_empty_region() -> () {
+ omp.teams {
+ // expected-error @below {{region cannot be empty}}
+ omp.workdistribute {
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @invalid_workdistribute_no_terminator() -> () {
+ omp.teams {
+ // expected-error @below {{region must be terminated with omp.terminator}}
+ omp.workdistribute {
+ %c0 = arith.constant 0 : i32
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @invalid_workdistribute_wrong_terminator() -> () {
+ omp.teams {
+ // expected-error @below {{region must be terminated with omp.terminator}}
+ omp.workdistribute {
+ %c0 = arith.constant 0 : i32
+ func.return
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @invalid_workdistribute_multiple_terminators() -> () {
+ omp.teams {
+ // expected-error @below {{region must have exactly one terminator}}
+ omp.workdistribute {
+ %cond = arith.constant true
+ cf.cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ omp.terminator
+ ^bb2:
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @invalid_workdistribute_with_barrier() -> () {
+ omp.teams {
+ // expected-error @below {{explicit barriers are not allowed in workdistribute region}}
+ omp.workdistribute {
+ %c0 = arith.constant 0 : i32
+ omp.barrier
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @invalid_workdistribute_nested_parallel() -> () {
+ omp.teams {
+ // expected-error @below {{nested parallel constructs not allowed in workdistribute}}
+ omp.workdistribute {
+ omp.parallel {
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+// Test: nested teams not allowed in workdistribute
+func.func @invalid_workdistribute_nested_teams() -> () {
+ omp.teams {
+ // expected-error @below {{nested teams constructs not allowed in workdistribute}}
+ omp.workdistribute {
+ omp.teams {
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @invalid_workdistribute() -> () {
+// expected-error @below {{workdistribute must be nested under teams}}
+ omp.workdistribute {
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 8c846cd..3c2e0a3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3238,3 +3238,15 @@ func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
return
}
+// CHECK-LABEL: func.func @omp_workdistribute
+func.func @omp_workdistribute() {
+ // CHECK: omp.teams
+ omp.teams {
+ // CHECK: omp.workdistribute
+ omp.workdistribute {
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 19fd715..0c34ae4 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -14,3 +14,43 @@ func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) {
%r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
return
}
+
+// -----
+
+func.func @invalid_load_alignment(%arg0: !ptr.ptr<#ptr.generic_space>) -> i64 {
+ // expected-error@+1 {{alignment must be a power of 2}}
+ %r = ptr.load %arg0 alignment = 3 : !ptr.ptr<#ptr.generic_space> -> i64
+ return %r : i64
+}
+
+// -----
+
+func.func @invalid_store_alignment(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: i64) {
+ // expected-error@+1 {{alignment must be a power of 2}}
+ ptr.store %arg1, %arg0 alignment = 3 : i64, !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+// -----
+
+func.func @store_const(%arg0: !ptr.ptr<#test.const_memory_space>, %arg1: i64) {
+ // expected-error@+1 {{memory space is read-only}}
+ ptr.store %arg1, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#test.const_memory_space>
+ return
+}
+
+// -----
+
+func.func @llvm_load(%arg0: !ptr.ptr<#llvm.address_space<1>>) -> (memref<f32>) {
+ // expected-error@+1 {{type must be LLVM type with size, but got 'memref<f32>'}}
+ %0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<1>> -> memref<f32>
+ return %0 : memref<f32>
+}
+
+// -----
+
+func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32>) {
+ // expected-error@+1 {{type must be LLVM type with size, but got 'memref<f32>'}}
+ ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
+ return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index eed3272..3f3ad05 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -1,14 +1,7 @@
-// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
+// RUN: mlir-opt %s --verify-roundtrip
/// Check op assembly.
-// CHECK-LABEL: @ptr_add_type_offset
func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
- // CHECK: ptr.type_offset f32 : index
- // CHECK-NEXT: ptr.ptr_add %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
- // CHECK-NEXT: ptr.ptr_add %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
- // CHECK-NEXT: ptr.ptr_add nusw %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
- // CHECK-NEXT: ptr.ptr_add nuw %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
- // CHECK-NEXT: ptr.ptr_add inbounds %{{.*}}, %{{.*}} : <#ptr.generic_space>, index
%off = ptr.type_offset f32 : index
%res = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
%res0 = ptr.ptr_add none %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
@@ -19,7 +12,6 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
}
/// Check cast ops assembly.
-// CHECK-LABEL: @cast_ops
func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
%mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
@@ -27,3 +19,40 @@ func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.ge
%mr0 = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
return %res : memref<f32, #ptr.generic_space>
}
+
+/// Check load ops assembly.
+func.func @load_ops(%arg0: !ptr.ptr<#ptr.generic_space>) -> (f32, f32, f32, f32, f32, i64, i32) {
+ %0 = ptr.load %arg0 : !ptr.ptr<#ptr.generic_space> -> f32
+ %1 = ptr.load volatile %arg0 : !ptr.ptr<#ptr.generic_space> -> f32
+ %2 = ptr.load %arg0 nontemporal : !ptr.ptr<#ptr.generic_space> -> f32
+ %3 = ptr.load %arg0 invariant : !ptr.ptr<#ptr.generic_space> -> f32
+ %4 = ptr.load %arg0 invariant_group : !ptr.ptr<#ptr.generic_space> -> f32
+ %5 = ptr.load %arg0 atomic monotonic alignment = 8 : !ptr.ptr<#ptr.generic_space> -> i64
+ %6 = ptr.load volatile %arg0 atomic syncscope("workgroup") acquire nontemporal alignment = 4 : !ptr.ptr<#ptr.generic_space> -> i32
+ return %0, %1, %2, %3, %4, %5, %6 : f32, f32, f32, f32, f32, i64, i32
+}
+
+/// Check store ops assembly.
+func.func @store_ops(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: f32, %arg2: i64, %arg3: i32) {
+ ptr.store %arg1, %arg0 : f32, !ptr.ptr<#ptr.generic_space>
+ ptr.store volatile %arg1, %arg0 : f32, !ptr.ptr<#ptr.generic_space>
+ ptr.store %arg1, %arg0 nontemporal : f32, !ptr.ptr<#ptr.generic_space>
+ ptr.store %arg1, %arg0 invariant_group : f32, !ptr.ptr<#ptr.generic_space>
+ ptr.store %arg2, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#ptr.generic_space>
+ ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#ptr.generic_space>
+ return
+}
+
+/// Test load operations with llvm.address_space memory space
+func.func @llvm_load(%arg0: !ptr.ptr<#llvm.address_space<1>>) -> (f32, i32) {
+ %0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<1>> -> f32
+ %1 = ptr.load volatile %arg0 atomic acquire alignment = 4 : !ptr.ptr<#llvm.address_space<1>> -> i32
+ return %0, %1 : f32, i32
+}
+
+/// Test store operations with llvm.address_space memory space
+func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<2>>, %arg1: f32, %arg2: i64) {
+ ptr.store %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<2>>
+ ptr.store %arg2, %arg0 atomic release alignment = 8 : i64, !ptr.ptr<#llvm.address_space<2>>
+ return
+}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 308cf150..2752c49 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1479,7 +1479,7 @@ func.func @execute_region_no_inline() {
// CHECK-NEXT: scf.execute_region
// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
// CHECK-NEXT: scf.yield %[[VAL]] : i64
-// CHECK-NEXT: }
+// CHECK-NOT: no_inline
// -----
@@ -1912,3 +1912,16 @@ func.func @index_switch_fold_no_res() {
// CHECK-LABEL: func.func @index_switch_fold_no_res()
// CHECK-NEXT: "test.op"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: func @scf_for_all_step_size_0()
+// CHECK: scf.forall (%{{.*}}) = (0) to (1) step (0)
+func.func @scf_for_all_step_size_0() {
+ %x = arith.constant 0 : index
+ scf.forall (%i, %j) = (0, 4) to (1, 5) step (%x, 8) {
+ vector.print %x : index
+ scf.forall.in_parallel {}
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 7f457ef..5930a1d 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -28,14 +28,14 @@ func.func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
scf.for %i0 = %arg0 to %arg1 step %arg2 : i32 {
- scf.for %i1 = %arg0 to %arg1 step %arg2 : i32 {
+ scf.for unsigned %i1 = %arg0 to %arg1 step %arg2 : i32 {
}
}
return
}
// CHECK-LABEL: func @std_for_i32(
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
-// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
+// CHECK-NEXT: scf.for unsigned %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
func.func @scf_for_i64_iter(%arg1: i64, %arg2: i64) {
%c1_i64 = arith.constant 1 : i64
diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000..e010074
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+ return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
+ return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 6398161..600c4c7 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -598,7 +598,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [int16, bf16] ]
- %1 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
@@ -618,7 +618,7 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
// CHECK: tosa.rescale profiles: [ [pro_int] ]
// CHECK: tosa.rescale extensions: [ [int16] ]
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5150ee3..fd2a3f1 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -315,8 +315,8 @@ func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tens
// CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp
func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8}
- %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
- %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+ %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = PROPAGATE} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = PROPAGATE} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}
@@ -324,9 +324,9 @@ func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -
// CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp
func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
- // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8, nan_mode = "IGNORE"}
- %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
- %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8, nan_mode = IGNORE}
+ %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = IGNORE} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = IGNORE} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}
@@ -334,9 +334,9 @@ func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> t
// CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp
func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
- // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8, nan_mode = "IGNORE"}
- %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
- %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+ // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8, nan_mode = IGNORE}
+ %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = IGNORE} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = PROPAGATE} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}
@@ -345,9 +345,9 @@ func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4
// CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_val = 4 : i8, min_val = -2 : i8} : (tensor<4xi8>) -> tensor<4xi8>
- // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
- %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
- %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 2 : i8, min_val = -4 : i8, nan_mode = IGNORE} : (tensor<4xi8>) -> tensor<4xi8>
+ %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = PROPAGATE} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = IGNORE} : (tensor<4xi8>) -> tensor<4xi8>
return %1 : tensor<4xi8>
}
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// -----
+// CHECK-LABEL: @mul_zero_dynamic_nofold
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+// CHECK: return %[[MUL]]
+func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+ %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+ return %2 : tensor<?x17xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_dynamic_fold
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
+// CHECK: return %[[ARG0]]
+func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
+ %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
+ return %2 : tensor<?x17xf32>
+}
+
+// -----
+
// CHECK-LABEL: @select_same_value
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@@ -913,7 +940,7 @@ func.func @fold_resize_nearest(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x1
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
return %resize : tensor<1x15x13x1xi8>
}
@@ -925,7 +952,7 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
return %resize : tensor<1x15x13x1xi8>
}
@@ -1169,7 +1196,7 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
- %2 = tosa.rescale %1, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x1x1x1xi32>
+ %2 = tosa.rescale %1, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = DOUBLE_ROUND, scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x1x1x1xi32>
return %2 : tensor<1x1x1x1xi32>
}
@@ -1277,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
// CHECK: }
- %0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
- %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
- %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
- return %2 : tensor<3x600x1200xf32>
+ %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
+ %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+ return %1 : tensor<3x600x1200xf32>
}
// -----
@@ -1288,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK-LABEL: @do_not_fold_reciprocal_int
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
// CHECK: tosa.reciprocal
- %0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
- %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
- %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
- return %2 : tensor<3x600x1200xi32>
+ %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
+ %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+ return %1 : tensor<3x600x1200xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index e23ce430..aaf8371 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -31,7 +31,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) ->
func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multiplier: tensor<1xi32>) -> tensor<13x21x3xi32> {
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -40,7 +40,7 @@ func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multip
func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> {
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -50,7 +50,7 @@ func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -60,7 +60,7 @@ func.func @test_rescale_non_const_output_zp(%arg0: tensor<13x21x3xi32>, %output_
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index fad1bec..290773b 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -8,7 +8,7 @@ func.func @test_resize_large_image_size(%arg0: tensor<1x16384x16384x8xf32>) -> t
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect input/output height/width dims to be < 16384, got [OH, OW, IH, IW] = 32767, 32767, 16384, 16384}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x16384x16384x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x32767x32767x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x16384x16384x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x32767x32767x8xf32>
return %1 : tensor<1x32767x32767x8xf32>
}
@@ -20,7 +20,7 @@ func.func @test_resize_invalid_scale_numerator(%arg0: tensor<1x9x9x8xf32>) -> te
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect all scale numerator values to be <= (1 << 11), got scale_y_n=2049, scale_x_n=1}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x9x9x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x9x9x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -32,7 +32,7 @@ func.func @test_resize_invalid_downscale(%arg0: tensor<1x37x37x8xf32>) -> tensor
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect a downscale ratio larger than 1/16, got y=1/18, x=1/18}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x37x37x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x37x37x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -44,7 +44,7 @@ func.func @test_resize_invalid_offset_y(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x
%offset = tosa.const_shape { values = dense<[17, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect offsetY / scaleYNumerator to be in range [-1, 16), got 17/1}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -56,7 +56,7 @@ func.func @test_resize_invalid_offset_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x
%offset = tosa.const_shape { values = dense<[0, -2]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect offsetX / scaleXNumerator to be in range [-1, 16), got -2/1}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -68,7 +68,7 @@ func.func @test_resize_invalid_boarder_y(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[-17, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect borderY / scaleYNumerator to be in range [-16, 1), got -17/1}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -80,7 +80,7 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[0, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect borderX / scaleXNumerator to be in range [-16, 1), got 2/1}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -138,7 +138,7 @@ func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -150,7 +150,7 @@ func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
return %0 : tensor<1xi16>
}
@@ -162,7 +162,7 @@ func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
@@ -174,7 +174,7 @@ func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -186,7 +186,7 @@ func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -198,7 +198,7 @@ func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -210,7 +210,7 @@ func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
@@ -222,7 +222,7 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3bccb32..41c3243 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -6,6 +6,15 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
+
+func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
+ // expected-error@+1{{'tosa.cast' op requires the same shape for all operands and results}}
+ %1 = "tosa.cast"(%arg0) : (tensor<i1>) -> tensor<5xi32>
+ return %1 : tensor<5xi32>
+}
+
+// -----
+
func.func @test_const() -> tensor<1xf32> {
// expected-error@+1{{'tosa.const' op expected same attr/result element types}}
%0 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>
@@ -1182,7 +1191,7 @@ func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tenso
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expect all scale values to be > 0, got 2, 0, -1, 2}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
@@ -1194,7 +1203,7 @@ func.func @test_resize_invalid_wholly_divisible_height(%arg0: tensor<1x8x8x8xf32
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expected (input_height - 1) * scale_y_n - offset_y + border_y to be wholly divisible by scale_y_d, got ((8 - 1) * 1 - 0 + 0) / 3}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8x8x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8x8x8xf32>
return %1 : tensor<1x8x8x8xf32>
}
@@ -1206,7 +1215,7 @@ func.func @test_resize_invalid_output_height(%arg0: tensor<1x8x8x8xf32>) -> tens
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op calculated output height did not match expected: calculated=15, expected=9}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x9x8x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x9x8x8xf32>
return %1 : tensor<1x9x8x8xf32>
}
@@ -1218,7 +1227,7 @@ func.func @test_resize_invalid_wholly_divisible_width(%arg0: tensor<1x8x8x8xf32>
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op expected (input_width - 1) * scale_x_n - offset_x + border_x to be wholly divisible by scale_x_d, got ((8 - 1) * 1 - 0 + 0) / 3}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8x8x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8x8x8xf32>
return %1 : tensor<1x8x8x8xf32>
}
@@ -1230,7 +1239,7 @@ func.func @test_resize_invalid_output_width(%arg0: tensor<1x8x8x8xf32>) -> tenso
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op calculated output width did not match expected: calculated=15, expected=9}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8x9x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8x9x8xf32>
return %1 : tensor<1x8x9x8xf32>
}
@@ -1242,7 +1251,7 @@ func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op calculated output width did not match expected: calculated=1, expected=5}}
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xf32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xf32>
return %resize : tensor<3x1x5x7xf32>
}
@@ -1255,7 +1264,7 @@ func.func @broadcast_resize_bilinear_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op calculated output height did not match expected: calculated=1, expected=4}}
- %resize = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4x5x7xi32>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4x5x7xi32>
return %resize : tensor<3x4x5x7xi32>
}
@@ -1472,7 +1481,7 @@ func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor
%input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
%output_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1484,7 +1493,7 @@ func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tenso
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1496,7 +1505,7 @@ func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> t
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1508,7 +1517,7 @@ func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor
%input_zp = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1520,7 +1529,7 @@ func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tens
%input_zp = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1532,7 +1541,7 @@ func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tens
%input_zp = "tosa.const"() {values = dense<1> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1544,7 +1553,7 @@ func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tens
%input_zp = "tosa.const"() {values = dense<1> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1557,7 +1566,7 @@ func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> ten
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1569,7 +1578,7 @@ func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> ten
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1581,7 +1590,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1593,7 +1602,7 @@ func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> te
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1605,7 +1614,7 @@ func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> te
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1617,7 +1626,7 @@ func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> t
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1629,7 +1638,7 @@ func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor
%input_zp = "tosa.const"() {values = dense<1> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1641,7 +1650,7 @@ func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1653,7 +1662,7 @@ func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1665,7 +1674,7 @@ func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1677,7 +1686,7 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1689,7 +1698,7 @@ func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tenso
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
return %0 : tensor<1xi16>
}
@@ -2010,7 +2019,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
- %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
return %r : tensor<1x1xi8>
}
@@ -2023,7 +2032,7 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
- %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 3154f54..3138ce2 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -255,7 +255,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x64x64x8xbf16>
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op illegal: requires [bf16] but not enabled in target}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xbf16>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xbf16>
return %1 : tensor<1x64x64x8xbf16>
}
@@ -377,7 +377,7 @@ func.func @test_single_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// CHECK tosa.rescale
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}
@@ -389,7 +389,7 @@ func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "DOUBLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = DOUBLE_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}
@@ -401,7 +401,7 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "INEXACT_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = INEXACT_ROUND, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}
@@ -420,7 +420,7 @@ func.func @test_rescale_non_const_multiplier(%arg0: tensor<13x21x3xi32>, %multip
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = SINGLE_ROUND, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -430,7 +430,7 @@ func.func @test_rescale_non_const_shift(%arg0: tensor<13x21x3xi32>, %shift: tens
%zps = "tosa.const"() {values = dense<0> : tensor<1xi32> } : () -> tensor<1xi32>
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #2}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %zps, %zps {rounding_mode = SINGLE_ROUND, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -510,7 +510,7 @@ func.func @test_rescale_non_const_input_zp(%arg0: tensor<13x21x3xi32>, %input_zp
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op expected compile time resolvable constant, but got variable value for operand #3}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0184d2b..a693a66 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -437,7 +437,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
%input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
return %0 : tensor<1x1x1x1x13x21x3xi8>
}
@@ -965,7 +965,7 @@ func.func @test_resize_scale_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x7970x64
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op failed level check: scale_y_n/scale_y_d <= MAX_SCALE}}
- %1 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} :
+ %1 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} :
(tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x7970x64x8xf32>
return %1 : tensor<1x7970x64x8xf32>
}
@@ -977,7 +977,7 @@ func.func @test_resize_scale_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x7970
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op failed level check: scale_x_n/scale_x_d <= MAX_SCALE}}
- %1 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} :
+ %1 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} :
(tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x7970x8xf32>
return %1 : tensor<1x64x7970x8xf32>
}
@@ -1009,7 +1009,7 @@ func.func @test_resize_tensor_size_invalid(%arg0: tensor<1x23178x23178x1xf32>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x23178x23178x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x23178x23178x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 30361a8..bee0eb1 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -105,7 +105,7 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
%shift = "tosa.const"() {values = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
%rescale_input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
%rescale_output_zp = "tosa.const"() <{values = dense<27> : tensor<1xi8>}> : () -> tensor<1xi8>
- %3 = tosa.rescale %2, %multiplier, %shift, %rescale_input_zp, %rescale_output_zp {rounding_mode = "DOUBLE_ROUND", scale32 = true, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x3xi8>
+ %3 = tosa.rescale %2, %multiplier, %shift, %rescale_input_zp, %rescale_output_zp {rounding_mode = DOUBLE_ROUND, scale32 = true, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x3xi8>
return %3 : tensor<1x1x1x3xi8>
}
@@ -240,14 +240,21 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// -----
// CHECK-LABEL: clamp_propagate
func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32, nan_mode = PROPAGATE} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: clamp_ignore
func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32, nan_mode = IGNORE} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp_ignore_enum
+func.func @test_clamp_ignore_enum(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32, nan_mode = #tosa.nan_mode<IGNORE>} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -371,14 +378,28 @@ func.func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>)
}
// -----
-// CHECK-LABEL: maximum
+// CHECK-LABEL: test_max
func.func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.maximum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.maximum %arg0, %arg1 {} : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_max_ignore
+func.func @test_max_ignore(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.maximum %arg0, %arg1 { nan_mode = IGNORE } : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
-// CHECK-LABEL: minimum
+// CHECK-LABEL: test_max_propagate
+func.func @test_max_propagate(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.maximum %arg0, %arg1 { nan_mode = PROPAGATE } : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_min
func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.minimum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
@@ -750,12 +771,12 @@ func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tens
}
// -----
-// CHECK-LABEL: resize
+// CHECK-LABEL: test_resize
func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
%scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
@@ -765,7 +786,7 @@ func.func @test_resize_unranked_output(%arg0: tensor<1x32x32x8xf32>) -> tensor<*
%scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}
@@ -775,12 +796,22 @@ func.func @test_resize_unranked_input(%arg0: tensor<*xf32>) -> tensor<1x64x64x8x
%scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
// -----
-// CHECK-LABEL: cast
+// CHECK-LABEL: test_resize_enum
+func.func @test_resize_enum(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
+ %scale = tosa.const_shape { values = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = #tosa.resize_mode<BILINEAR> } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
+ return %1 : tensor<1x64x64x8xf32>
+}
+
+// -----
+// CHECK-LABEL: test_cast1
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
@@ -807,7 +838,7 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() <{values = dense<127> : tensor<1xi8>}> : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
@@ -818,7 +849,18 @@ func.func @test_rescale_i16_zp32768(%arg0 : tensor<2xi8>) -> tensor<2xi16> {
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<32768> : tensor<1xi16>} : () -> tensor<1xi16>
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<2xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
+// -----
+// CHECK-LABEL: test_rescale_i16_rounding_mode
+func.func @test_rescale_i16_rounding_mode(%arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<32768> : tensor<1xi16>} : () -> tensor<1xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = #tosa.rounding_mode<SINGLE_ROUND>, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<2xi16>
return %0 : tensor<2xi16>
}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index fad4859..58a73d6 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -322,6 +322,6 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op illegal: requires [pro_fp] but not enabled in target}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 9438179..a5784b3 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -254,7 +254,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xi8>) -> tensor<1x64x64x8xi32> {
%offset = tosa.const_shape { values = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{'tosa.resize' op illegal: requires [pro_int] but not enabled in target}}
- %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x32x32x8xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xi32>
+ %1 = tosa.resize %arg0, %scale, %offset, %border { mode = BILINEAR } : (tensor<1x32x32x8xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xi32>
return %1 : tensor<1x64x64x8xi32>
}
@@ -293,6 +293,6 @@ func.func @test_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
%input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op illegal: requires [pro_int] but not enabled in target}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = SINGLE_ROUND, scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index e957bdd..a64f69a 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -9,8 +9,8 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
- // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
- %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
return %r : tensor<1x1xui8>
}
@@ -24,8 +24,8 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16>
- // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
- %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
+ // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
return %r : tensor<1x1xi8>
}
@@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
- // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
+ // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8>
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 7b8fc24..80f06f1 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -106,7 +106,7 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
%shift = "tosa.const"() {values = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
%input_zp = "tosa.const"() {values = dense<43> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<52> : tensor<1xi8>} : () -> tensor<1xi8>
- %6 = tosa.rescale %arg1, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ %6 = tosa.rescale %arg1, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
%7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
@@ -1071,7 +1071,7 @@ func.func @resize_int_horizontal(%arg0: tensor<1x15x13x1xi8>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x23x179x1xi8>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi8>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi8>
return
}
@@ -1083,7 +1083,7 @@ func.func @resize_int_vertical(%arg0: tensor<1x49x42x1xi16>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x112x220x1xi16>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x49x42x1xi16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi16>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x49x42x1xi16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi16>
return
}
@@ -1095,7 +1095,7 @@ func.func @resize_int_power_of_two_upscale(%arg0: tensor<1x23x19x1xi8>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x353x289x1xi32>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x23x19x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi32>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x23x19x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi32>
return
}
@@ -1107,7 +1107,7 @@ func.func @resize_int_power_of_two_upscale_offsetted(%arg0: tensor<1x41x26x1xi16
%offset = tosa.const_shape { values = dense<[-7, -7]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[7, 7]> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x328x208x1xi48>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x41x26x1xi16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi48>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x41x26x1xi16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xi48>
return
}
@@ -1118,7 +1118,7 @@ func.func @resize_fp_horizontal(%arg0: tensor<1x50x48x1xf32>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x106x85x1xf32>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return
}
@@ -1129,7 +1129,7 @@ func.func @resize_fp_vertical(%arg0: tensor<1x50x48x1xf32>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x128x13x1xf32>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return
}
@@ -1141,7 +1141,7 @@ func.func @resize_fp_power_of_two_upscale(%arg0: tensor<1x23x23x1xf32>) {
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x89x89x1xf32>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x23x23x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x23x23x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return
}
@@ -1153,7 +1153,7 @@ func.func @resize_fp_power_of_two_upscale_offsetted(%arg0: tensor<1x50x48x1xf32>
%offset = tosa.const_shape { values = dense<[-31, -31]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[31, 31]> : tensor<2xindex> } : () -> !tosa.shape<2>
// CHECK: -> tensor<1x1600x1536x1xf32>
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x50x48x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return
}
@@ -1165,7 +1165,7 @@ func.func @resize_negative_output_dim(%arg0: tensor<1x3x1x1xi8>) {
%offset = tosa.const_shape { values = dense<[6, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[-15, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{calculated output height and width must be non-negative, got height = -5, width = 0}}
- %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8>
+ %0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index cab1420..88ec027 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -14,7 +14,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
- %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
return %r : tensor<1x1xi8>
}
@@ -26,6 +26,6 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
- %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+ %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 2a937b0..f58ddb1 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -367,7 +367,7 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
- %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = SINGLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
return %0 : tensor<i16>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f86fb38..e7381e0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -625,40 +625,40 @@ func.func @insert_extract_transpose_2d(
// -----
// CHECK-LABEL: insert_extract_chain
-// CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32>
+// CHECK-SAME: %[[V334:[a-zA-Z0-9]*]]: vector<3x3x4xf32>
// CHECK-SAME: %[[V34:[a-zA-Z0-9]*]]: vector<3x4xf32>
// CHECK-SAME: %[[V4:[a-zA-Z0-9]*]]: vector<4xf32>
-func.func @insert_extract_chain(%v234: vector<2x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>)
+func.func @insert_extract_chain(%v334: vector<3x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>)
-> (vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>) {
// CHECK-NEXT: %[[A34:.*]] = vector.insert
- %A34 = vector.insert %v34, %v234[0]: vector<3x4xf32> into vector<2x3x4xf32>
+ %A34 = vector.insert %v34, %v334[0]: vector<3x4xf32> into vector<3x3x4xf32>
// CHECK-NEXT: %[[B34:.*]] = vector.insert
- %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<2x3x4xf32>
+ %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<3x3x4xf32>
// CHECK-NEXT: %[[A4:.*]] = vector.insert
- %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<2x3x4xf32>
+ %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<3x3x4xf32>
// CHECK-NEXT: %[[B4:.*]] = vector.insert
- %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<2x3x4xf32>
+ %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<3x3x4xf32>
// Case 2.a. [1, 1] == insertpos ([1, 1])
// Match %A4 insertionpos and fold to its source(i.e. %V4).
- %r0 = vector.extract %B4[1, 1]: vector<4xf32> from vector<2x3x4xf32>
+ %r0 = vector.extract %B4[1, 1]: vector<4xf32> from vector<3x3x4xf32>
// Case 3.a. insertpos ([1]) is a prefix of [1, 0].
// Traverse %B34 to its source(i.e. %V34@[*0*]).
// CHECK-NEXT: %[[R1:.*]] = vector.extract %[[V34]][0]
- %r1 = vector.extract %B34[1, 0]: vector<4xf32> from vector<2x3x4xf32>
+ %r1 = vector.extract %B34[1, 0]: vector<4xf32> from vector<3x3x4xf32>
// Case 4. [1] is a prefix of insertpos ([1, 1]).
// Cannot traverse %B4.
// CHECK-NEXT: %[[R2:.*]] = vector.extract %[[B4]][1]
- %r2 = vector.extract %B4[1]: vector<3x4xf32> from vector<2x3x4xf32>
+ %r2 = vector.extract %B4[1]: vector<3x4xf32> from vector<3x3x4xf32>
// Case 5. [0] is disjoint from insertpos ([1, 1]).
// Traverse %B4 to its dest(i.e. %A4@[0]).
// Traverse %A4 to its dest(i.e. %B34@[0]).
// Traverse %B34 to its dest(i.e. %A34@[0]).
// Match %A34 insertionpos and fold to its source(i.e. %V34).
- %r3 = vector.extract %B4[0]: vector<3x4xf32> from vector<2x3x4xf32>
+ %r3 = vector.extract %B4[0]: vector<3x4xf32> from vector<3x3x4xf32>
// CHECK: return %[[V4]], %[[R1]], %[[R2]], %[[V34]]
return %r0, %r1, %r2, %r3:
@@ -1057,8 +1057,8 @@ func.func @insert_fold_same_rank(%v: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK-LABEL: func @insert_no_fold_scalar_to_0d(
// CHECK-SAME: %[[v:.*]]: vector<f32>)
-// CHECK: %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
-// CHECK: return %[[extract]]
+// CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK: return %[[cst]]
func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = vector.insert %cst, %v [] : f32 into vector<f32>
@@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
+ %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
+ return %1 : vector<32x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
+ %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>
+ return %1 : vector<32x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
+// CHECK: return %[[VAL_0]] : vector<32x2xf32>
+// CHECK: }
+func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>
+ %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
+// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+ %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
+// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
+// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
+func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
+// CHECK: return %[[VAL_1]] : vector<2x4xf32>
+// CHECK: }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
+ %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
+// CHECK: return %[[VAL_1]] : vector<32x2xf32>
+// CHECK: }
+func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>
+ %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -2569,6 +2669,112 @@ func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3
// -----
+// +---------------------------------------------------------------------------
+// Tests for InsertChainFullyInitialized .
+// +---------------------------------------------------------------------------
+// This pattern should fire when all vector elements are overwritten by vector.insert
+// at static positions, replacing the insert chain with vector.from_elements.
+// CHECK-LABEL: func.func @fully_insert_scalar_to_vector(
+// CHECK-SAME: %[[DEST:.+]]: vector<2xi64>, %[[SRC1:.+]]: i64, %[[SRC2:.+]]: i64)
+// CHECK: %[[RES:.+]] = vector.from_elements %[[SRC1]], %[[SRC2]] : vector<2xi64>
+// CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_scalar_to_vector(%dest : vector<2xi64>, %src1 : i64, %src2 : i64) -> vector<2xi64> {
+ %v1 = vector.insert %src1, %dest[0] : i64 into vector<2xi64>
+ %v2 = vector.insert %src2, %v1[1] : i64 into vector<2xi64>
+ return %v2 : vector<2xi64>
+}
+
+// -----
+
+// Same as the above test, but with vector insertions.
+// CHECK-LABEL: func.func @fully_insert_vector_to_vector(
+// CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: vector<2xi64>, %[[SRC2:.+]]: vector<2xi64>)
+// CHECK: %[[ELE1:.+]]:2 = vector.to_elements %[[SRC1]] : vector<2xi64>
+// CHECK: %[[ELE2:.+]]:2 = vector.to_elements %[[SRC2]] : vector<2xi64>
+// CHECK: %[[RES:.+]] = vector.from_elements %[[ELE1]]#0, %[[ELE1]]#1, %[[ELE2]]#0, %[[ELE2]]#1 : vector<2x2xi64>
+// CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_vector_to_vector(%dest : vector<2x2xi64>, %src1 : vector<2xi64>, %src2 : vector<2xi64>) -> vector<2x2xi64> {
+ %v1 = vector.insert %src1, %dest[0] : vector<2xi64> into vector<2x2xi64>
+ %v2 = vector.insert %src2, %v1[1] : vector<2xi64> into vector<2x2xi64>
+ return %v2 : vector<2x2xi64>
+}
+
+// -----
+
+// Test InsertChainFullyInitialized pattern with overlapping insertions.
+// 1. The first op inserts %src2 at [0,0].
+// 2. The second op inserts %src1 at [0,0], [0,1], overwriting %src2 at [0,0].
+// 3. The third op inserts %src1 at [1,0], [1,1].
+// 4. The fourth op inserts %src2 at [1,1], overwriting the existing value at [1,1].
+// CHECK-LABEL: func.func @fully_insert_to_vector_overlap_1(
+// CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: vector<2xi64>, %[[SRC2:.+]]: i64)
+// CHECK: %[[ELE1:.+]]:2 = vector.to_elements %[[SRC1]] : vector<2xi64>
+// CHECK: %[[ELE2:.+]]:2 = vector.to_elements %[[SRC1]] : vector<2xi64>
+// CHECK: %[[RES:.+]] = vector.from_elements %[[ELE1]]#0, %[[ELE1]]#1, %[[ELE2]]#0, %[[SRC2]] : vector<2x2xi64>
+// CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_to_vector_overlap_1(%dest : vector<2x2xi64>, %src1 : vector<2xi64>, %src2 : i64) -> vector<2x2xi64> {
+ %v0 = vector.insert %src2, %dest[0, 0] : i64 into vector<2x2xi64>
+ %v1 = vector.insert %src1, %v0[0] : vector<2xi64> into vector<2x2xi64>
+ %v2 = vector.insert %src1, %v1[1] : vector<2xi64> into vector<2x2xi64>
+ %v3 = vector.insert %src2, %v2[1, 1] : i64 into vector<2x2xi64>
+ return %v3 : vector<2x2xi64>
+}
+
+// -----
+
+// Test InsertChainFullyInitialized pattern with overlapping insertions.
+// The vector inserted at last should overwrite the previously inserted scalars.
+// CHECK-LABEL: func.func @fully_insert_to_vector_overlap_2(
+// CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: i64, %[[SRC2:.+]]: i64, %[[SRC3:.+]]: vector<2xi64>, %[[SRC4:.+]]: vector<2xi64>)
+// CHECK: %[[ELE1:.+]]:2 = vector.to_elements %[[SRC3]] : vector<2xi64>
+// CHECK: %[[ELE2:.+]]:2 = vector.to_elements %[[SRC4]] : vector<2xi64>
+// CHECK: %[[RES:.+]] = vector.from_elements %[[ELE1]]#0, %[[ELE1]]#1, %[[ELE2]]#0, %[[ELE2]]#1 : vector<2x2xi64>
+// CHECK-NEXT: return %[[RES]]
+func.func @fully_insert_to_vector_overlap_2(%dest : vector<2x2xi64>, %src1 : i64, %src2 : i64, %src3 : vector<2xi64>, %src4 : vector<2xi64>) -> vector<2x2xi64> {
+ %v0 = vector.insert %src1, %dest[0, 0] : i64 into vector<2x2xi64>
+ %v1 = vector.insert %src2, %v0[0, 1] : i64 into vector<2x2xi64>
+ %v2 = vector.insert %src3, %v1[0] : vector<2xi64> into vector<2x2xi64>
+ %v3 = vector.insert %src4, %v2[1] : vector<2xi64> into vector<2x2xi64>
+ return %v3 : vector<2x2xi64>
+}
+
+// -----
+
+// Negative test for InsertChainFullyInitialized pattern when only some elements are overwritten.
+// CHECK-LABEL: func.func @negative_partially_insert_vector_to_vector(
+// CHECK-SAME: %[[DEST:.+]]: vector<2x2xi64>, %[[SRC1:.+]]: vector<2xi64>, %[[SRC2:.+]]: i64)
+// CHECK: %[[V0:.+]] = vector.insert %[[SRC1]], %[[DEST]] [0] : vector<2xi64> into vector<2x2xi64>
+// CHECK: %[[V1:.+]] = vector.insert %[[SRC2]], %[[V0]] [0, 0] : i64 into vector<2x2xi64>
+// CHECK: return %[[V1]]
+func.func @negative_partially_insert_vector_to_vector(%dest : vector<2x2xi64>, %src1 : vector<2xi64>, %src2 : i64) -> vector<2x2xi64> {
+ %v1 = vector.insert %src1, %dest[0] : vector<2xi64> into vector<2x2xi64>
+ %v2 = vector.insert %src2, %v1[0, 0] : i64 into vector<2x2xi64>
+ return %v2 : vector<2x2xi64>
+}
+
+// -----
+
+// Negative test when intermediate results have more than one user.
+// CHECK-LABEL: func.func @negative_intermediate_insert_multiple_users(
+// CHECK-SAME: %[[DEST:.+]]: vector<3xi64>, %[[SRC1:.+]]: i64, %[[SRC2:.+]]: i64, %[[SRC3:.+]]: i64, %[[SRC4:.+]]: i64)
+// CHECK: %[[V0:.+]] = vector.insert %[[SRC1]], %[[DEST]] [0] : i64 into vector<3xi64>
+// CHECK: %[[V1:.+]] = vector.insert %[[SRC2]], %[[V0]] [1] : i64 into vector<3xi64>
+// CHECK: %[[V2:.+]] = vector.insert %[[SRC3]], %[[V1]] [2] : i64 into vector<3xi64>
+// CHECK: %[[V3:.+]] = vector.insert %[[SRC4]], %[[V1]] [2] : i64 into vector<3xi64>
+func.func @negative_intermediate_insert_multiple_users(%dest : vector<3xi64>, %src1 : i64, %src2 : i64, %src3 : i64, %src4 : i64) -> (vector<3xi64>, vector<3xi64>) {
+ %v1 = vector.insert %src1, %dest[0] : i64 into vector<3xi64>
+ %v2 = vector.insert %src2, %v1[1] : i64 into vector<3xi64>
+ %v3_0 = vector.insert %src3, %v2[2] : i64 into vector<3xi64>
+ %v3_1 = vector.insert %src4, %v2[2] : i64 into vector<3xi64>
+ return %v3_0, %v3_1 : vector<3xi64>, vector<3xi64>
+}
+
+// +---------------------------------------------------------------------------
+// End of Tests For InsertChainFullyInitialized.
+// +---------------------------------------------------------------------------
+
+// -----
+
// CHECK-LABEL: func.func @insert_2d_splat_constant
// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
@@ -3520,3 +3726,17 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
return %v_1 : vector<4xf32>
}
+
+// -----
+
+llvm.mlir.global constant @my_symbol() : i32
+
+// CHECK-LABEL: func @from_address_of_regression
+// CHECK: %[[a:.*]] = llvm.mlir.addressof @my_symbol
+// CHECK: %[[b:.*]] = vector.broadcast %[[a]] : !llvm.ptr to vector<1x!llvm.ptr>
+// CHECK: return %[[b]]
+func.func @from_address_of_regression() -> vector<1x!llvm.ptr> {
+ %a = llvm.mlir.addressof @my_symbol : !llvm.ptr
+ %b = vector.from_elements %a : vector<1x!llvm.ptr>
+ return %b : vector<1x!llvm.ptr>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c21de56..6ee70fd 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,6 +1305,26 @@ func.func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16
// -----
+//===----------------------------------------------------------------------===//
+// vector.maskedload
+//===----------------------------------------------------------------------===//
+
+func.func @maskedload_nonpositive_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
+ // expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %val = vector.maskedload %base[%index], %mask, %pass { alignment = 0 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
+ return
+}
+
+// -----
+
+func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
+ // expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %val = vector.maskedload %base[%index], %mask, %pass { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
+ return
+}
+
+// -----
+
func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
@@ -1336,6 +1356,26 @@ func.func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>
// -----
+//===----------------------------------------------------------------------===//
+// vector.maskedstore
+//===----------------------------------------------------------------------===//
+
+func.func @maskedstore_nonpositive_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
+ // expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.maskedstore %base[%index], %mask, %value { alignment = 0 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
+ return
+}
+
+// -----
+
+func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
+ // expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.maskedstore %base[%index], %mask, %value { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
+ return
+}
+
+// -----
+
func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
@@ -1430,6 +1470,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
// -----
+func.func @gather_nonpositive_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+ // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 0 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+ // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1491,6 +1549,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
// -----
+func.func @scatter_nonpositive_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.scatter %base[%c0][%indices], %mask, %value { alignment = 0 }
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
@@ -1531,6 +1607,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
// -----
+func.func @expand_nonpositive_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 0 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
@@ -1563,6 +1653,20 @@ func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>
// -----
+func.func @compress_nonpositive_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.compressstore %base[%c0], %mask, %value { alignment = 0 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @compress_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.compressstore %base[%c0], %mask, %value { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
func.func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
// expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
%0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
@@ -1912,10 +2016,17 @@ func.func @vector_load(%src : memref<?xi8>) {
// -----
-func.func @invalid_load_alignment(%memref: memref<4xi32>) {
- %c0 = arith.constant 0 : index
+func.func @load_nonpositive_alignment(%memref: memref<4xi32>, %c0: index) {
+ // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %val = vector.load %memref[%c0] { alignment = 0 } : memref<4xi32>, vector<4xi32>
+ return
+}
+
+// -----
+
+func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
- %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+ %val = vector.load %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
}
@@ -1934,8 +2045,15 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
// -----
-func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
- %c0 = arith.constant 0 : index
+func.func @store_nonpositive_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
+ // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ vector.store %val, %memref[%c0] { alignment = 0 } : memref<4xi32>, vector<4xi32>
+ return
+}
+
+// -----
+
+func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2e630bf..5e8bfd0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -524,3 +524,17 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector
vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
return
}
+
+// -----
+
+// Test pattern LinearizeVectorFromElements.
+
+// CHECK-LABEL: test_vector_from_elements
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32
+func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+ // CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32>
+ // CHECK: return %[[CAST]] : vector<2x2xf32>
+ %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 625ffc1..550e52a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -149,7 +149,7 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
}
// CHECK-LABEL: @vector_broadcast
-func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>) {
+func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>, %g: !llvm.ptr<1>) {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
%0 = vector.broadcast %a : f32 to vector<f32>
// CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
@@ -164,6 +164,8 @@ func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: ve
%5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
%6 = vector.broadcast %f : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>>
+ // CHECK-NEXT: vector.broadcast %{{.*}} : !llvm.ptr<1> to vector<8x16x!llvm.ptr<1>>
+ %7 = vector.broadcast %g : !llvm.ptr<1> to vector<8x16x!llvm.ptr<1>>
return
}
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4b38db7..524a4f4 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -121,6 +121,13 @@ func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vecto
return %mul: vector<8x16xf32>
}
+// See https://github.com/llvm/llvm-project/pull/152957
+// CHECK-LABEL: func.func @negative_non_vector_type
+func.func @negative_non_vector_type(%lhs: f32, %rhs: f32) -> f32 {
+ %mul = arith.mulf %lhs, %rhs : f32
+ return %mul: f32
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
new file mode 100644
index 0000000..8fac608
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL
+
+//===----------------------------------------------------------------------===//
+// Test UnrollFromElements.
+//===----------------------------------------------------------------------===//
+
+// CHECK-UNROLL-LABEL: @unroll_from_elements_2d
+// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x2xf32>
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// CHECK-UNROLL-LABEL: @unroll_from_elements_3d
+// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x1x2xf32>
+func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
+// 1-D vector.from_elements should not be unrolled.
+
+// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d
+// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK-UNROLL-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: return %[[RES]] : vector<2xf32>
+func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> {
+ %0 = vector.from_elements %arg0, %arg1 : vector<2xf32>
+ return %0 : vector<2xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5be267c..0e1bad6 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -81,7 +81,7 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
// CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
+// CHECK: %[[INIT:.*]] = ub.poison : vector<2x[3]xf32>
// CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
// CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
// CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
@@ -198,7 +198,7 @@ func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<
// CANON-NOT: scf.if
// CANON: tensor.extract
// CANON: tensor.extract
-// CANON: [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : f32 into vector<2xf32>
+// CANON: [[FINAL:%.+]] = vector.from_elements %{{.+}}, %{{.+}} : vector<2xf32>
// CANON-NEXT: return [[FINAL]] : vector<2xf32>
func.func @gather_tensor_1d_all_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
%mask = arith.constant dense <true> : vector<2xi1>
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 0bf38ba..3b51e6b 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
-// TODO: Seperate tests for vector.multi_reduction -> vector.contract and
+// TODO: Separate tests for vector.multi_reduction -> vector.contract and
// * pre-op + vector.contract -> vector.contract,
// * vector.contract + post-op -> vector.contract.
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index ef881ba..577b06d 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -40,7 +40,7 @@ func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> v
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
- %0 = vector.splat %arg1 : vector<1x4xindex>
+ %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
@@ -53,7 +53,7 @@ func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) ->
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
- %0 = vector.splat %arg1 : vector<1x[4]xindex>
+ %0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
@@ -94,12 +94,12 @@ func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf
// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
+// CHECK: %[[SPLAT:.*]] = vector.broadcast %[[ARG1]] : index to vector<1x4xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>
func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
- %0 = vector.splat %arg1 : vector<1x4xindex>
+ %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
@@ -108,12 +108,12 @@ func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vec
// CHECK-LABEL: func.func @broadcast_scalar_and_vec_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
+// CHECK: %[[SPLAT:.*]] = vector.broadcast %[[ARG1]] : index to vector<1x[4]xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
// CHECK: return %[[ADD]] : vector<1x[4]xindex>
func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
- %0 = vector.splat %arg1 : vector<1x[4]xindex>
+ %0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
@@ -787,7 +787,7 @@ func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) ->
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
- %0 = vector.splat %arg2 : vector<1xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
return
}
@@ -813,9 +813,9 @@ func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2:
// CHECK-LABEL: @negative_store_scalable
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
-// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG2]] : f32 to vector<[1]xf32>
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
- %0 = vector.splat %arg2 : vector<[1]xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<[1]xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
return
}
@@ -823,9 +823,9 @@ func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f3
// CHECK-LABEL: @negative_store_memref_of_vec
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
func.func @negative_store_memref_of_vec(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
-// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG2]] : f32 to vector<1xf32>
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
- %0 = vector.splat %arg2 : vector<1xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
return
}
@@ -833,9 +833,9 @@ func.func @negative_store_memref_of_vec(%arg0: memref<?xvector<1xf32>>, %arg1: i
// CHECK-LABEL: @negative_store_more_than_one_element
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
-// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG2]] : f32 to vector<4xf32>
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
- %0 = vector.splat %arg2 : vector<4xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<4xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
return
}
@@ -843,10 +843,10 @@ func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: ind
// CHECK-LABEL: @negative_store_no_single_use
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
-// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG2]] : f32 to vector<1xf32>
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
// CHECK: return %[[RES:.*]] : vector<1xf32>
- %0 = vector.splat %arg2 : vector<1xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
return %0 : vector<1xf32>
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 1b54d54..45afbff 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -285,19 +285,19 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
%c0 = arith.constant 0 : index
// CHECK: %[[MASK0:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1>
- %mask0 = vector.splat %m : vector<14x7xi1>
+ %mask0 = vector.broadcast %m : i1 to vector<14x7xi1>
%0 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
// CHECK: %[[MASK1:.*]] = vector.broadcast %{{.*}} : i1 to vector<16x14xi1>
- %mask1 = vector.splat %m : vector<16x14xi1>
+ %mask1 = vector.broadcast %m : i1 to vector<16x14xi1>
%1 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask1 {in_bounds = [true, false, true, false], permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {in_bounds = [false, false, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
// CHECK: %[[MASK3:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1>
- %mask2 = vector.splat %m : vector<14x7xi1>
+ %mask2 = vector.broadcast %m : i1 to vector<14x7xi1>
%2 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
@@ -337,7 +337,7 @@ func.func @transfer_write_permutations_tensor_masked(
%c0 = arith.constant 0 : index
// CHECK: %[[MASK:.*]] = vector.broadcast %[[M]] : i1 to vector<16x14x7x8xi1>
- %mask0 = vector.splat %m : vector<16x14x7x8xi1>
+ %mask0 = vector.broadcast %m : i1 to vector<16x14x7x8xi1>
%res = vector.transfer_write %vec, %dst[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, tensor<?x?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index ae8fce7..8750582 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1299,7 +1299,7 @@ func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: gpu.yield %[[VEC]], %[[VAL]]
-// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector<f32>
+// CHECK-PROP: vector.broadcast %[[W]]#1 : f32 to vector<f32>
func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
%0 = "some_def"() : () -> (vector<f32>)
@@ -1803,3 +1803,56 @@ func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
// CHECK-DIST-AND-PROP: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (4, 8) : index, index
// CHECK-DIST-AND-PROP: %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1]
// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>
+
+// -----
+func.func @warp_propagate_duplicated_operands_in_yield(%laneid: index) {
+ %r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+ %0 = "some_def"() : () -> (vector<32xf32>)
+ %1 = "some_other_def"() : () -> (vector<32xf32>)
+ %2 = math.exp %1 : vector<32xf32>
+ gpu.yield %2, %0, %0 : vector<32xf32>, vector<32xf32>, vector<32xf32>
+ }
+ "some_use"(%r#0) : (vector<1xf32>) -> ()
+ return
+}
+
+// CHECK-PROP-LABEL : func.func @warp_propagate_duplicated_operands_in_yield(
+// CHECK-PROP : %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+// CHECK-PROP : %{{.*}} = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP : %[[T3:.*]] = "some_other_def"() : () -> vector<32xf32>
+// CHECK-PROP : gpu.yield %[[T3]] : vector<32xf32>
+// CHECK-PROP : }
+// CHECK-PROP : %[T1:.*] = math.exp %[[W]] : vector<1xf32>
+// CHECK-PROP : "some_use"(%[[T1]]) : (vector<1xf32>) -> ()
+
+// -----
+
+func.func @warp_step_distribute(%buffer: memref<128xindex>) {
+ %laneid = gpu.lane_id
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
+ %seq = vector.step : vector<32xindex>
+ gpu.yield %seq : vector<32xindex>
+ }
+ vector.transfer_write %r, %buffer[%laneid] : vector<1xindex>, memref<128xindex>
+ return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_step_distribute(
+// CHECK-PROP: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-PROP: %[[LANE_ID_VEC:.*]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+// CHECK-PROP: vector.transfer_write %[[LANE_ID_VEC]], %{{.*}} : vector<1xindex>, memref<128xindex>
+
+// -----
+
+func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memref<128xindex>) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xindex>) {
+ %seq = vector.step : vector<64xindex>
+ gpu.yield %seq : vector<64xindex>
+ }
+ vector.transfer_write %r, %buffer[%laneid] : vector<2xindex>, memref<128xindex>
+ return
+}
+
+// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
+// CHECK-PROP-NOT: vector.broadcast
+// CHECK-PROP: vector.step : vector<64xindex>
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
new file mode 100644
index 0000000..01068cb
--- /dev/null
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: wasmssa.func nested @func_0(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: wasmssa.if %[[VAL_0]] : {
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 5.000000e-01 : f32
+// CHECK: wasmssa.block_return %[[VAL_1]] : f32
+// CHECK: } "else "{
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 2.500000e-01 : f32
+// CHECK: wasmssa.block_return %[[VAL_2]] : f32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_3:.*]]: f32):
+// CHECK: wasmssa.return %[[VAL_3]] : f32
+wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
+ %cond = wasmssa.local_get %arg0 : ref to i32
+ wasmssa.if %cond : {
+ %c0 = wasmssa.const 0.5 : f32
+ wasmssa.block_return %c0 : f32
+ } else {
+ %c1 = wasmssa.const 0.25 : f32
+ wasmssa.block_return %c1 : f32
+ } >^bb1
+ ^bb1(%retVal: f32):
+ wasmssa.return %retVal : f32
+}
+
+// CHECK-LABEL: wasmssa.func nested @func_1(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 0 : i64
+// CHECK: wasmssa.if %[[VAL_0]] : {
+// CHECK: %[[VAL_3:.*]] = wasmssa.const 1 : i32
+// CHECK: wasmssa.local_set %[[VAL_1]] : ref to i32 to %[[VAL_3]] : i32
+// CHECK: wasmssa.block_return
+// CHECK: } > ^bb1
+// CHECK: ^bb1:
+// CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32
+// CHECK: wasmssa.return %[[VAL_4]] : i32
+wasmssa.func nested @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 {
+ %cond = wasmssa.local_get %arg0 : ref to i32
+ %var = wasmssa.local of type i32
+ %zero = wasmssa.const 0
+ wasmssa.if %cond : {
+ %c1 = wasmssa.const 1 : i32
+ wasmssa.local_set %var : ref to i32 to %c1 : i32
+ wasmssa.block_return
+ } >^bb1
+ ^bb1:
+ %res = wasmssa.local_get %var : ref to i32
+ wasmssa.return %res : i32
+}
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
new file mode 100644
index 0000000..47551db
--- /dev/null
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK: wasmssa.memory @mem0 public !wasmssa<limit[0: 65536]>
+wasmssa.memory @mem0 public !wasmssa<limit[0:65536]>
+
+// CHECK: wasmssa.memory @mem1 nested !wasmssa<limit[512:]>
+wasmssa.memory @mem1 !wasmssa<limit[512:]>
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
new file mode 100644
index 0000000..5a874f4
--- /dev/null
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK: wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0: 65536]>
+wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0:65536]>
+
+// CHECK: wasmssa.table @tab1 nested !wasmssa<tabletype !wasmssa.funcref [348:]>
+wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffa..228ef69d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
// -----
func.func @create_nd_tdesc_8(%src: ui64) {
- // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+ // expected-error@+1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
// -----
func.func @create_nd_tdesc_9(%src: ui64) {
- // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+ // expected-error@+1 {{expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
}
// -----
-func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -387,12 +387,45 @@ func.func @load_gather_vc_3(%src: ui64) {
// -----
func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
- // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ // expected-error@+1 {{op operand #0 must be TensorDesc describing regions of interested data}}
xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
return
}
// -----
+func.func @prefetch_offset_wi_2(%src: memref<16xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %1 = xegpu.create_tdesc %src, %offsets : memref<16xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
+ // expected-error@+1 {{offsets not allowed}}
+ xegpu.prefetch %1[%offsets]: !xegpu.tensor_desc<1x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @prefetch_offset_wi_3(%src: memref<16xf32>) {
+ // expected-error@+1 {{Expects offsets}}
+ xegpu.prefetch %src: memref<16xf32>
+ return
+}
+
+// -----
+func.func @prefetch_offset_wi_4(%src: memref<16xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{offset_align_byte only allowed with integer source.}}
+ xegpu.prefetch %src[%offsets] <{offset_align_byte = 4}>: memref<16xf32>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @prefetch_offset_wi_5(%src: i64) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{offset_align_byte is required with integer source.}}
+ xegpu.prefetch %src[%offsets] : i64, vector<1xindex>
+ return
+}
+
+// -----
func.func @load_gather_offset_sg(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%mask = arith.constant dense<1>: vector<8xi1>
@@ -408,7 +441,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error@+1 {{value elements must match chunk size}}
- %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<3xf32>
return
}
@@ -417,7 +450,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
- // expected-error@+1 {{value elements must match chunk size}}
+ // expected-error@+1 {{Mask should match value except the chunk size dim}}
xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
@@ -428,18 +461,56 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
- // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
- xegpu.store %val, %src[%offsets], %mask
+ // expected-error@+1 {{op operand #1 must be TensorDesc describing regions of interested data}}
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
return
}
// -----
+func.func @store_scatter_offset_wi_3(%src: memref<16xf16>) {
+ %val = arith.constant dense<2.9>: vector<1xf16>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expects offsets}}
+ xegpu.store %val, %src, %mask
+ : vector<1xf16>, memref<16xf16>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_4(%src: !xegpu.tensor_desc<1x1xf32, #xegpu.scatter_tdesc_attr<>>) {
+ %val = arith.constant dense<2.9>: vector<1xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{offsets not allowed}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<1xf16>, !xegpu.tensor_desc<1x1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_4(%src: !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<>>) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{offsets not allowed}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1> -> vector<2xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_3(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expects offsets}}
+ %2 = xegpu.load %src, %mask <{chunk_size = 2}> : ui64, vector<1xi1> -> vector<2xf16>
+ return
+}
+
+// -----
func.func @load_gather_offset_wi_2(%src: ui64) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error@+1 {{value elements must match chunk size}}
- %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<3xf16>
return
}
@@ -447,7 +518,7 @@ func.func @load_gather_offset_wi_2(%src: ui64) {
func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
- // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ // expected-error@+1 {{op operand #0 must be TensorDesc describing regions of interested data}}
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
return
}
@@ -743,3 +814,108 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
#xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2], order = [0, 1, 2]>>
return
}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{repeated dim (2) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [2, 2]>
+func.func @slice_attr_repeat_dim() {
+ %offsets = arith.constant {layout_result_0 = #s} dense<0.8> : vector<16x8xindex>
+ return
+}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{invalid dim (3) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [3]>
+func.func @slice_attr_repeat_dim() {
+ %offsets = arith.constant {layout_result_0 = #s} dense<0.8> : vector<16x8xindex>
+ return
+}
+
+// -----
+func.func @create_mem_desc_non_slm() {
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
+ // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
+ return
+}
+
+// -----
+func.func @create_mem_desc_mismatch_sizes() {
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ // expected-error@+1 {{failed to verify that all of {source, mem_desc} have same size in bits}}
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x32xf16>
+ return
+}
+
+// -----
+func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{failed to verify that all of {mem_desc, res} have same element type}}
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf32>
+ return
+}
+
+// -----
+func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{result shape must not exceed mem_desc shape}}
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
+ return
+}
+
+// -----
+func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
+ // expected-error@+1 {{mem_desc must be 2D}}
+ %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
+ return
+}
+
+// -----
+func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
+ // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+
+// -----
+func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<32x32xf16>) {
+ // expected-error@+1 {{data shape must not exceed mem_desc shape}}
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<32x32xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+
+// -----
+func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
+ // expected-error@+1 {{mem_desc must be 2D.}}
+ xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{result shape must not exceed source shape}}
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
+ // expected-error@+1 {{result must inherit the source strides}}
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{failed to verify that all of {src, res} have same element type}}
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{result rank must not exceed source rank}}
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
+ return
+}
+
diff --git a/mlir/test/Dialect/XeGPU/layout.mlir b/mlir/test/Dialect/XeGPU/layout.mlir
index 017dacc..e4b4e22 100644
--- a/mlir/test/Dialect/XeGPU/layout.mlir
+++ b/mlir/test/Dialect/XeGPU/layout.mlir
@@ -50,4 +50,27 @@ gpu.func @convert_layout_wg(%a: vector<32x64xf16>) {
gpu.return
}
+gpu.func @slice_attr() {
+ //CHECK: arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
+ gpu.return
+}
+
+gpu.func @nested_slice_attr() {
+ //CHECK: arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>, dims = [1]>} dense<8> : vector<16xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>, dims = [1]>} dense<8> : vector<16xindex>
+ gpu.return
+}
+
+gpu.func @softmax_dim_0(%arg0: vector<256x128xf32>) -> vector<256x128xf32> {
+ %cst = arith.constant dense<0.000000e+00> : vector<128xf32>
+ %0 = math.exp %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xf32>
+ //CHECK: vector.multi_reduction <add>, {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
+ %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
+ //CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<128xf32> to vector<256x128xf32>
+ %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<128xf32> to vector<256x128xf32>
+ %3 = arith.divf %0, %2 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xf32>
+ gpu.return %3 : vector<256x128xf32>
+}
+
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 6be2371..bb37902 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -62,28 +62,28 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
}
-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
-
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
+
gpu.return
}
-// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
-
- %c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-
+
gpu.return
}
-// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
@@ -94,10 +94,10 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
gpu.return
}
-// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
-gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
gpu.return
@@ -123,7 +123,7 @@ gpu.func @prefetch_nd_2(%src: memref<48x64xf16>) {
// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<48x64xf16>, %arg1: index, %arg2: index) {
gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %src : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK: xegpu.prefetch_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16>
@@ -271,7 +271,7 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index, %arg2: index) {
gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : index) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
%1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
// CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
%2 = xegpu.load_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
@@ -290,7 +290,7 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
%1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
// CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
%2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
@@ -323,7 +323,7 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
%1 = arith.constant dense<1.0>: vector<32xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
%2 = xegpu.create_nd_tdesc %dst : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
// CHECK: xegpu.store_nd %[[C]], %[[R0]][%arg1] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
xegpu.store_nd %1, %2[%x] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
@@ -356,7 +356,7 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
%1 = arith.constant dense<1.0>: vector<2xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
%2 = xegpu.create_nd_tdesc %src : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
// CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
@@ -508,6 +508,34 @@ gpu.func @simt_load_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @simt_load_4(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: vector<1xindex>, %[[arg2:.*]]: vector<1xi1>) {
+gpu.func @simt_load_4(%arg0: memref<256xf16>, %arg1: vector<1xindex>, %arg2: vector<1xi1>) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+ %0 = xegpu.load %arg0[%arg1], %arg2 <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_5(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: vector<1xindex>, %[[arg2:.*]]: vector<1xi1>) {
+gpu.func @simt_load_5(%arg0: memref<256xf16>, %arg1: vector<1xindex>, %arg2: vector<1xi1>) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ %0 = xegpu.load %arg0[%arg1], %arg2 : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_6(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: index, %[[arg2:.*]]: i1) {
+gpu.func @simt_load_6(%arg0: memref<256xf16>, %arg1: index, %arg2: i1) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] <{chunk_size = 8 : i64}> : memref<256xf16>, index, i1 -> vector<8xf16>
+ %0 = xegpu.load %arg0[%arg1], %arg2 <{chunk_size = 8 : i64}> : memref<256xf16>, index, i1 -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_7(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: index, %[[arg2:.*]]: i1) {
+gpu.func @simt_load_7(%arg0: memref<256xf16>, %arg1: index, %arg2: i1) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] : memref<256xf16>, index, i1 -> f16
+ %0 = xegpu.load %arg0[%arg1], %arg2 : memref<256xf16>, index, i1 -> f16
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_load_4(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_load_4(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
@@ -621,6 +649,34 @@ gpu.func @simt_store_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @simt_store_4(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: vector<1xindex>, %[[arg3:.*]]: vector<1xi1>) {
+gpu.func @simt_store_4(%arg0: vector<8xf16>, %arg1: memref<256xf16>, %arg2: vector<1xindex>, %arg3: vector<1xi1>) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ xegpu.store %arg0, %arg1[%arg2], %arg3 <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_5(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: index, %[[arg3:.*]]: i1) {
+gpu.func @simt_store_5(%arg0: vector<8xf16>, %arg1: memref<256xf16>, %arg2: index, %arg3: i1) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, index, i1
+ xegpu.store %arg0, %arg1[%arg2], %arg3 <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, index, i1
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_6(%[[arg0:.*]]: vector<1xf16>, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: vector<1xindex>, %[[arg3:.*]]: vector<1xi1>) {
+gpu.func @simt_store_6(%arg0: vector<1xf16>, %arg1: memref<256xf16>, %arg2: vector<1xindex>, %arg3: vector<1xi1>) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ xegpu.store %arg0, %arg1[%arg2], %arg3 : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_7(%[[arg0:.*]]: f16, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: index, %[[arg3:.*]]: i1) {
+gpu.func @simt_store_7(%arg0: f16, %arg1: memref<256xf16>, %arg2: index, %arg3: i1) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] : f16, memref<256xf16>, index, i1
+ xegpu.store %arg0, %arg1[%arg2], %arg3 : f16, memref<256xf16>, index, i1
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store_4(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store_4(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
@@ -662,8 +718,8 @@ gpu.func @prefetch(%src: ui64) {
gpu.func @prefetch_offset(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
- xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, offset_align_byte = 2 : i64}> : ui64, vector<4xindex>
+ xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, offset_align_byte = 2}>: ui64, vector<4xindex>
gpu.return
}
@@ -751,4 +807,72 @@ gpu.func @fence() {
gpu.return
}
+// CHECK-LABEL: gpu.func @create_mem_desc({{.*}}) {
+gpu.func @create_mem_desc() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride({{.*}}) {
+gpu.func @create_mem_desc_with_stride() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
+ gpu.return
+}
+
+
+// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
new file mode 100644
index 0000000..547c735
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
+
+//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
+gpu.module @test {
+ gpu.func @slice_attr() -> vector<128xindex> {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+ gpu.return %step : vector<128xindex>
+ }
+
+ gpu.func @nested_slice_attr() -> vector<128xindex> {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
+ gpu.return %0 : vector<128xindex>
+ }
+
+} \ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 628a485..e5cc65e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+#map = affine_map<()[s0] -> (s0 floordiv 4)>
+#map1 = affine_map<()[s0] -> (s0 mod 4)>
+
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -12,6 +15,30 @@ gpu.module @test_round_robin_assignment {
gpu.return
}
+ // CHECK-LABEL: create_nd_tdesc_with_shared_data
+ // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
+ //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
+ //CHECK: [[C16:%.+]] = arith.constant 16 : index
+ //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
+ //CHECK: [[C64:%.+]] = arith.constant 64 : index
+ //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
+ //CHECK: [[C0:%.+]] = arith.constant 0 : index
+ //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
+ //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+ //CHECK: [[C128:%.+]] = arith.constant 128 : index
+ //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
+ //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
+ //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
+ //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+ gpu.return
+ }
+
// CHECK-LABEL: load_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
new file mode 100644
index 0000000..6ff7a94
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+ // CHECK-LABEL: create_nd_tdesc_no_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.create_nd_tdesc
+ %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: load_nd_tdesc_with_offset
+ gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+ // CHECK-NOT: xegpu.load_nd
+ %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: store_nd_with_offset
+ gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.store_nd
+ %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ xegpu.store_nd %load, %tdesc[0, 0]
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
+ // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.prefetch_nd
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: dpas
+ // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+ gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.create_nd_tdesc
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
+ // CHECK-NOT: xegpu.create_nd_tdesc
+ // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
+ // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ // CHECK-NOT: xegpu.dpas
+ %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
+ -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a[0, 0]
+ : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x256xf16>
+ -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ %load_b = xegpu.load_nd %tdesc_b[0, 0]
+ : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x256xf16>
+ %dpas = xegpu.dpas %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
new file mode 100644
index 0000000..afb2bf8
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -0,0 +1,368 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+gpu.module @test_distribution {
+ // CHECK-LABEL: create_nd_tdesc_no_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+ // CHECK: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: create_nd_tdesc_with_ptr
+ // CHECK-SAME: %[[ARG_0:.*]]: ui64
+ gpu.func @create_nd_tdesc_with_ptr(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+ // CHECK: xegpu.create_nd_tdesc %[[ARG_0]], shape : [{{.*}}, {{.*}}], strides : [{{.*}}, {{.*}}] : ui64
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %c1 = arith.constant 1 : index
+ %tdesc = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides: [%w, %c1] : ui64
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: load_nd_tdesc_with_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+ //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+ //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+ //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: store_nd_with_offsets
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
+ //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+ //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+ //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ xegpu.store_nd %load, %tdesc[0, 0]
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+ // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+ //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+ //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+ //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %cst0 = arith.constant 0 : index
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %tdesc[%cst0, %cst0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: dpas
+ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a[0, 0]
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ %load_b = xegpu.load_nd %tdesc_b[0, 0]
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x128xf16>
+ %dpas = xegpu.dpas %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: dpas_no_sg_data
+ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
+ %load_a = xegpu.load_nd %tdesc_a[0, 0]
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
+ %load_b = xegpu.load_nd %tdesc_b[0, 0]
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
+ %dpas = xegpu.dpas %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: dpas_with_no_create_nd_desc
+ gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+ // CHECK-NOT: vector<32x32xf32>
+ %dpas = xegpu.dpas %a, %b
+ {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: broadcast_dim1
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+ gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+ -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<256x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<256x1xf32> to vector<256x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: broadcast_dim0
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+ -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<1x128xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<1x128xf32> to vector<32x128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gemm_with_load_store_offset
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1024x1024xf16>, %[[ARG_1:.*]]: memref<1024x1024xf16>, %[[ARG_2:.*]]: memref<1024x1024xf32>
+ gpu.func @gemm_with_load_store_offset(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c1024 = arith.constant 1024 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = arith.muli %block_id_x, %c128 : index
+ %1 = arith.muli %block_id_y, %c128 : index
+ %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+ // CHECK: [[DESC_A:%.+]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x128xf16>
+ // CHECK: [[DESC_B:%.+]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x16xf16>
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+ // load_nd with offset
+ %5 = xegpu.load_nd %2[%0, %1] : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+ %6 = xegpu.load_nd %3[%0, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+ %7 = xegpu.load_nd %4[%c0, %1] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+ // scf.for loop
+ // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+ // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+ // CHECK-SAME: (vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>)
+ // CHECK: [[c:%.+]] = xegpu.dpas [[arg4]], [[arg5]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+ // CHECK: [[a:%.+]] = xegpu.load_nd [[DESC_A]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+ // CHECK: [[b:%.+]] = xegpu.load_nd [[DESC_B]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+ // CHECK: scf.yield [[a]], [[b]], [[c]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>
+ %8:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %6, %arg5 = %7, %arg6 = %5)
+ -> (vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>) {
+ // load_nd with offset inside loop
+ %9 = xegpu.dpas %arg4, %arg5, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
+ : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+ %10 = xegpu.load_nd %3[%arg3, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+ %11 = xegpu.load_nd %4[%c0, %arg3] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+ scf.yield %10, %11, %9 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>
+ }
+ // store_nd with offset
+ xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @subgroup_id_range
+ gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+ %sg_id = gpu.subgroup_id : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c31 = arith.constant 31 : index
+ %c3 = arith.constant 3 : index
+ %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+ %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+ %cond = arith.andi %cond1, %cond2 : i1
+ scf.if %cond {
+ // CHECK-NOT: index.sub
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ } {sg_id_range = #xegpu.range<[0, 32]>}
+ %cond3 = arith.cmpi sge, %sg_id, %c2 : index
+ %cond4 = arith.cmpi slt, %sg_id, %c31 : index
+ %cond5 = arith.andi %cond3, %cond4 : i1
+ scf.if %cond5 {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+ %tdesc = xegpu.create_nd_tdesc %src2 : memref<128x64xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<128x64xf32>
+ %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+ }{sg_id_range = #xegpu.range<[2, 18]>}
+ gpu.return
+ }
+
+ // CHECK-LABEL: @subgroup_id_range_nested_if
+ gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+ %sg_id = gpu.subgroup_id : index
+ %c1 = arith.constant 1 : i1
+ %c3 = arith.constant 3 : index
+ %c32 = arith.constant 32 : index
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+ %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+ %cond = arith.andi %cond1, %cond2 : i1
+ scf.if %c1 {
+ scf.if %cond {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+ %td = xegpu.create_nd_tdesc %src1 : memref<128x64xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %ld = xegpu.load_nd %td[0, 0]
+ : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<128x64xf32>
+ %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+ }
+ } {sg_id_range = #xegpu.range<[3, 19]>}
+ gpu.return
+ }
+
+ // CHECK-LABEL: @load_gather
+ // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+ gpu.func @load_gather(%src : memref<?xf16>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+ : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @store_scatter
+ // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+ gpu.func @store_scatter(%dest : memref<256xf16>) {
+ // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+ %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+ : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @load_with_non_unit_chunk_size
+ // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+ gpu.func @load_with_non_unit_chunk_size(%src : memref<?xf16>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+ : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+ gpu.return
+ }
+
+ // CHECK-LABEL: distribute_load_matrix
+ // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+ gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+ //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[c2:%.+]] = arith.constant 2 : index
+ //CHECK: [[c4:%.+]] = arith.constant 4 : index
+ //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+ //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+ //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
+ //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
+ //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
+ //CHECK: [[c64:%.+]] = arith.constant 64 : index
+ //CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
+ //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
+ %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+ %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
+ gpu.return
+ }
+
+ //CHECK-LABEL: distribute_store_matrix
+ //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+ gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) {
+ //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+ //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[c2:%.+]] = arith.constant 2 : index
+ //CHECK: [[c4:%.+]] = arith.constant 4 : index
+ //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+ //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+ //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
+ //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+ //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+ //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
+ //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
+ //CHECK: [[c64:%.+]] = arith.constant 64 : index
+ //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+ //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
+ %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+ xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d4b0037..f4a49da 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,34 +4,26 @@
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C8:.*]] = arith.constant 8 : index
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
- // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
- // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
- // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C256:.*]] = arith.constant 256 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
- // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
- // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
- // CHECK: %[[C128:.*]] = arith.constant 128 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
- // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
- -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
+ //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+ //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+ //CHECK: [[C32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+ //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
+ //CHECK: [[C0:%.+]] = arith.constant 0 : index
+ //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+ //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+ //CHECK: [[C256:%.+]] = arith.constant 256 : index
+ //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+ //CHECK: [[C128:%.+]] = arith.constant 128 : index
+ //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+ //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
}
// CHECK-LABEL: load_nd_tdesc
@@ -347,7 +339,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
// CHECK-LABEL: @subgroup_id_range_nested_if
gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
%sg_id = gpu.subgroup_id : index
- %c1 = arith.constant 1 : i1
+ %c1 = arith.constant 1 : i1
%c3 = arith.constant 3 : index
%c32 = arith.constant 32 : index
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
@@ -373,4 +365,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
} {sg_id_range = #xegpu.range<[3, 19]>}
gpu.return
}
+
+ // CHECK-LABEL: distribute_constant
+ gpu.func @distribute_constant() {
+ // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/common_folders.mlir b/mlir/test/Dialect/common_folders.mlir
new file mode 100644
index 0000000..92598b4
--- /dev/null
+++ b/mlir/test/Dialect/common_folders.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --test-fold-type-converting-op --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @test_fold_unary_op_f32_to_si32(
+func.func @test_fold_unary_op_f32_to_si32() -> tensor<4x2xsi32> {
+ // CHECK-NEXT: %[[POSITIVE_ONE:.*]] = arith.constant dense<1> : tensor<4x2xsi32>
+ // CHECK-NEXT: return %[[POSITIVE_ONE]] : tensor<4x2xsi32>
+ %operand = arith.constant dense<5.1> : tensor<4x2xf32>
+ %sign = test.sign %operand : (tensor<4x2xf32>) -> tensor<4x2xsi32>
+ return %sign : tensor<4x2xsi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_fold_binary_op_f32_to_i1(
+func.func @test_fold_binary_op_f32_to_i1() -> tensor<8xi1> {
+ // CHECK-NEXT: %[[FALSE:.*]] = arith.constant dense<false> : tensor<8xi1>
+ // CHECK-NEXT: return %[[FALSE]] : tensor<8xi1>
+ %lhs = arith.constant dense<5.1> : tensor<8xf32>
+ %rhs = arith.constant dense<4.2> : tensor<8xf32>
+ %less_than = test.less_than %lhs, %rhs : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xi1>
+ return %less_than : tensor<8xi1>
+}
diff --git a/mlir/test/Examples/standalone/test.toy b/mlir/test/Examples/standalone/test.toy
index c91b3cf..e99bab5 100644
--- a/mlir/test/Examples/standalone/test.toy
+++ b/mlir/test/Examples/standalone/test.toy
@@ -2,10 +2,13 @@
# RUN: -DCMAKE_CXX_COMPILER=%host_cxx -DCMAKE_C_COMPILER=%host_cc \
# RUN: -DLLVM_ENABLE_LIBCXX=%enable_libcxx -DMLIR_DIR=%mlir_cmake_dir \
# RUN: -DLLVM_USE_LINKER=%llvm_use_linker \
-# RUN: -DPython3_EXECUTABLE=%python
-# RUN: "%cmake_exe" --build . --target check-standalone | tee %t | FileCheck %s
+# RUN: -DPython3_EXECUTABLE=%python \
+# RUN: -DPython_EXECUTABLE=%python
+# RUN: "%cmake_exe" --build . --target check-standalone | tee %t
+# RUN: FileCheck --input-file=%t %s
# Note: The number of checked tests is not important. The command will fail
# if any fail.
# CHECK: Passed
+# CHECK-NOT: Failed
# UNSUPPORTED: target={{.*(windows|android).*}}
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60..c306341 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// the moved op to be visited twice.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
@@ -119,3 +119,23 @@ func.func @erase_nested_block() -> i32 {
}): () -> (i32)
return %a : i32
}
+
+
+// CHECK-LABEL: func.func @unreachable_replace_with_new_op
+// CHECK: "test.new_op"
+// CHECK: "test.replace_with_new_op"
+// CHECK-SAME: unreachable
+// CHECK: "test.new_op"
+func.func @unreachable_replace_with_new_op() {
+ "test.br"()[^bb1] : () -> ()
+^bb1:
+ %a = "test.replace_with_new_op"() : () -> (i32)
+ "test.br"()[^end] : () -> () // Test jumping over the unreachable block is visited as well.
+^unreachable:
+ %b = "test.replace_with_new_op"() {test.unreachable} : () -> (i32)
+ return
+^end:
+ %c = "test.replace_with_new_op"() : () -> (i32)
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index 06a6e22..9d04357 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -9,7 +9,12 @@
// RUN: FileCheck %s
func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
- %res = linalg.matmul_transpose_a ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ %res = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
%xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
@@ -56,7 +61,7 @@ func.func @main() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
- %matmul_transpose_a = transform.structured.match ops{["linalg.matmul_transpose_a"]} in %module
+ %matmul_transpose_a = transform.structured.match ops{["linalg.matmul"]} in %module
: (!transform.any_op) -> !transform.any_op
// Step 1: Tile for size [4] x [4], which corresponds to SVLs x SVLs, where
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 8f74976..25a338d 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -6,6 +6,15 @@
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @main() {
// This buffer is properly aligned. There should be no error.
// CHECK-NOT: ^ memref is not aligned to 8
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 26c731c..4c6a48d 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -5,6 +5,14 @@
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
%cst = arith.constant 1.0 : f32
memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32
diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
index 8b6308e..1ac1030 100644
--- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
@@ -1,11 +1,20 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @cast_to_static_dim(%m: memref<?xf32>) -> memref<10xf32> {
%0 = memref.cast %m : memref<?xf32> to memref<10xf32>
return %0 : memref<10xf32>
diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
index 95b9db2..be9417b 100644
--- a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -6,6 +6,15 @@
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
// Put memref.copy in a function, otherwise the memref.cast may fold.
func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
index 2e3f271..ef4af62 100644
--- a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
@@ -6,6 +6,15 @@
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @main() {
%c4 = arith.constant 4 : index
%alloca = memref.alloca() : memref<1xf32>
diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
index b87e5bd..2e42648 100644
--- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
@@ -1,12 +1,20 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
-// RUN: -lower-affine \
+// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @load(%memref: memref<1xf32>, %index: index) {
memref.load %memref[%index] : memref<1xf32>
return
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index 12253fa..dd000c6 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -5,6 +5,14 @@
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
%cst = arith.constant 1.0 : f32
memref.store %cst, %memref[%index] : memref<?xf32>
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index ec7e408..9fbe5bc 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -1,12 +1,22 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
+// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @subview(%memref: memref<1xf32>, %offset: index) {
memref.subview %memref[%offset] [1] [1] :
memref<1xf32> to
diff --git a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir
index e4aab32..f37a6d6 100644
--- a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir
@@ -8,6 +8,17 @@
// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN: -test-cf-assert \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func private @cast_to_static_dim(%t: tensor<?xf32>) -> tensor<10xf32> {
%0 = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
return %0 : tensor<10xf32>
diff --git a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir
index c6d8f698..e9e5c04 100644
--- a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir
@@ -1,10 +1,20 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -one-shot-bufferize \
-// RUN: -buffer-deallocation-pipeline \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
-// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
func.func @main() {
diff --git a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir
index 8e3cab7..73fcec4 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir
@@ -8,6 +8,17 @@
// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN: -test-cf-assert \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @extract(%tensor: tensor<1xf32>, %index: index) {
tensor.extract %tensor[%index] : tensor<1xf32>
return
diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index 28f9be0..341a59e 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -8,6 +8,17 @@
// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN: -test-cf-assert \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
func.func @extract_slice(%tensor: tensor<1xf32>, %offset: index) {
tensor.extract_slice %tensor[%offset] [1] [1] : tensor<1xf32> to tensor<1xf32>
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f32.mlir
index 0ee0166..219367a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f32.mlir
@@ -46,7 +46,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
%c0 = arith.constant 0 : index
%f10 = arith.constant 10.0 : f32
- %acc = vector.splat %f10 : vector<[4]x[4]xf32>
+ %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
%vector_i32 = llvm.intr.stepvector : vector<[4]xi32>
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
@@ -103,7 +103,7 @@ func.func @test_masked_outerproduct_with_accumulator_4x4xf32() {
%ones = arith.constant dense<1> : vector<[4]xi32>
%f10 = arith.constant 10.0 : f32
- %acc = vector.splat %f10 : vector<[4]x[4]xf32>
+ %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
%step_vector = llvm.intr.stepvector : vector<[4]xi32>
%vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f64.mlir
index 8e81210..059f24a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f64.mlir
@@ -52,7 +52,7 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
%ones = arith.constant dense<1> : vector<[2]xi32>
%f10 = arith.constant 10.0 : f64
- %acc = vector.splat %f10 : vector<[2]x[2]xf64>
+ %acc = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
%step_vector = llvm.intr.stepvector : vector<[2]xi32>
%vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
%vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
@@ -108,7 +108,7 @@ func.func @test_masked_outerproduct_with_accumulator_2x2xf64() {
%ones = arith.constant dense<1> : vector<[2]xi32>
%f10 = arith.constant 10.0 : f64
- %acc = vector.splat %f10 : vector<[2]x[2]xf64>
+ %acc = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
%step_vector = llvm.intr.stepvector : vector<[2]xi32>
%vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
%vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transfer-write-2d.mlir
index c3bf379..bf6900c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transfer-write-2d.mlir
@@ -10,7 +10,7 @@
// Vector store.
func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%c0 = arith.constant 0.0 : f32
- %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+ %zero = vector.broadcast %c0 : f32 to vector<[4]x[4]xf32>
vector.transfer_write %zero, %A[%base1, %base2] {in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
@@ -22,7 +22,7 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
- %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+ %zero = vector.broadcast %c0 : f32 to vector<[4]x[4]xf32>
vector.transfer_write %zero, %A[%base1, %base2], %mask {in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction.mlir
index c990432..192f291 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction.mlir
@@ -106,7 +106,7 @@ func.func @matvec_i32() {
// val = (123 * 314) * 4 * vscale
// so ...
%vscale = vector.vscale
- %vscale_v = vector.splat %vscale : vector<3xindex>
+ %vscale_v = vector.broadcast %vscale : index to vector<3xindex>
%vscale_i32 = arith.index_cast %vscale_v : vector<3xindex> to vector<3xi32>
%mv1_div = arith.divui %mv1, %vscale_i32 : vector<3xi32>
// ... val / vscale = 123 * 314 * 4 = 154488
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/scalable-interleave.mlir
index d3b1fa4..2d8180a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/scalable-interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/scalable-interleave.mlir
@@ -7,8 +7,8 @@
func.func @entry() {
%f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32
- %v1 = vector.splat %f1 : vector<[4]xf32>
- %v2 = vector.splat %f2 : vector<[4]xf32>
+ %v1 = vector.broadcast %f1 : f32 to vector<[4]xf32>
+ %v2 = vector.broadcast %f2 : f32 to vector<[4]xf32>
vector.print %v1 : vector<[4]xf32>
vector.print %v2 : vector<[4]xf32>
//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/interleave.mlir
index f812c25..740c742 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/interleave.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/interleave.mlir
@@ -6,8 +6,8 @@
func.func @entry() {
%f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32
- %v1 = vector.splat %f1 : vector<2x4xf32>
- %v2 = vector.splat %f2 : vector<2x4xf32>
+ %v1 = vector.broadcast %f1 : f32 to vector<2x4xf32>
+ %v2 = vector.broadcast %f2 : f32 to vector<2x4xf32>
vector.print %v1 : vector<2x4xf32>
vector.print %v2 : vector<2x4xf32>
//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-f32.mlir
index f7e2229..e25795a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-f32.mlir
@@ -14,9 +14,9 @@
!vector_type_R = vector<7xf32>
func.func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C {
- %a = vector.splat %fa: !vector_type_A
- %b = vector.splat %fb: !vector_type_B
- %c = vector.splat %fc: !vector_type_C
+ %a = vector.broadcast %fa: f32 to !vector_type_A
+ %b = vector.broadcast %fb: f32 to !vector_type_B
+ %c = vector.broadcast %fc: f32 to !vector_type_C
%d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
return %d: !vector_type_C
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-i64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-i64.mlir
index a19dfa1..0675102 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-i64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-i64.mlir
@@ -14,9 +14,9 @@
!vector_type_R = vector<7xi64>
func.func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C {
- %a = vector.splat %ia: !vector_type_A
- %b = vector.splat %ib: !vector_type_B
- %c = vector.splat %ic: !vector_type_C
+ %a = vector.broadcast %ia: i64 to !vector_type_A
+ %b = vector.broadcast %ib: i64 to !vector_type_B
+ %c = vector.broadcast %ic: i64 to !vector_type_C
%d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
return %d: !vector_type_C
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
index 639eed4..895b881 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
@@ -137,7 +137,7 @@ func.func @transfer_read_1d_mask_in_bounds(
// Non-contiguous, strided store.
func.func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = arith.constant -1.0 : f32
- %vf0 = vector.splat %fn1 : vector<7xf32>
+ %vf0 = vector.broadcast %fn1 : f32 to vector<7xf32>
vector.transfer_write %vf0, %A[%base1, %base2]
{permutation_map = affine_map<(d0, d1) -> (d0)>}
: vector<7xf32>, memref<?x?xf32>
@@ -147,7 +147,7 @@ func.func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : inde
// Non-contiguous, strided store.
func.func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = arith.constant -2.0 : f32
- %vf0 = vector.splat %fn1 : vector<7xf32>
+ %vf0 = vector.broadcast %fn1 : f32 to vector<7xf32>
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
vector.transfer_write %vf0, %A[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0)>}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir
index 009c137..80dff9d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir
@@ -100,7 +100,7 @@ func.func @transfer_read_2d_broadcast(
// Vector store.
func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = arith.constant -1.0 : f32
- %vf0 = vector.splat %fn1 : vector<1x4xf32>
+ %vf0 = vector.broadcast %fn1 : f32 to vector<1x4xf32>
vector.transfer_write %vf0, %A[%base1, %base2]
{permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
vector<1x4xf32>, memref<?x?xf32>
@@ -111,7 +111,7 @@ func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index)
func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = arith.constant -2.0 : f32
%mask = arith.constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
- %vf0 = vector.splat %fn1 : vector<1x4xf32>
+ %vf0 = vector.broadcast %fn1 : f32 to vector<1x4xf32>
vector.transfer_write %vf0, %A[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
vector<1x4xf32>, memref<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir
index d41d9c9..93e6a12 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir
@@ -62,7 +62,7 @@ func.func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
func.func @transfer_write_3d(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fn1 = arith.constant -1.0 : f32
- %vf0 = vector.splat %fn1 : vector<2x9x3xf32>
+ %vf0 = vector.broadcast %fn1 : f32 to vector<2x9x3xf32>
vector.transfer_write %vf0, %A[%o, %a, %b, %c]
: vector<2x9x3xf32>, memref<?x?x?x?xf32>
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read.mlir
index d1a2790..18084e3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read.mlir
@@ -45,7 +45,7 @@ func.func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
func.func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
%f0 = arith.constant 0.0 : f32
- %vf0 = vector.splat %f0 : vector<4xf32>
+ %vf0 = vector.broadcast %f0 : f32 to vector<4xf32>
vector.transfer_write %vf0, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>} :
vector<4xf32>, memref<?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-write.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-write.mlir
index def7081..2251738 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-write.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-write.mlir
@@ -5,7 +5,7 @@
func.func @transfer_write16_inbounds_1d(%A : memref<?xf32>, %base: index) {
%f = arith.constant 16.0 : f32
- %v = vector.splat %f : vector<16xf32>
+ %v = vector.broadcast %f : f32 to vector<16xf32>
vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]}
: vector<16xf32>, memref<?xf32>
@@ -14,7 +14,7 @@ func.func @transfer_write16_inbounds_1d(%A : memref<?xf32>, %base: index) {
func.func @transfer_write13_1d(%A : memref<?xf32>, %base: index) {
%f = arith.constant 13.0 : f32
- %v = vector.splat %f : vector<13xf32>
+ %v = vector.broadcast %f : f32 to vector<13xf32>
vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>}
: vector<13xf32>, memref<?xf32>
@@ -23,7 +23,7 @@ func.func @transfer_write13_1d(%A : memref<?xf32>, %base: index) {
func.func @transfer_write17_1d(%A : memref<?xf32>, %base: index) {
%f = arith.constant 17.0 : f32
- %v = vector.splat %f : vector<17xf32>
+ %v = vector.broadcast %f : f32 to vector<17xf32>
vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>}
: vector<17xf32>, memref<?xf32>
@@ -42,7 +42,7 @@ func.func @transfer_read_1d(%A : memref<?xf32>) -> vector<32xf32> {
func.func @transfer_write_inbounds_3d(%A : memref<4x4x4xf32>) {
%c0 = arith.constant 0: index
%f = arith.constant 0.0 : f32
- %v0 = vector.splat %f : vector<2x3x4xf32>
+ %v0 = vector.broadcast %f : f32 to vector<2x3x4xf32>
%f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32
%f3 = arith.constant 3.0 : f32
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/lit.local.cfg b/mlir/test/Integration/Dialect/XeVM/GPU/lit.local.cfg
new file mode 100644
index 0000000..d0d51c6
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+ config.unsupported = True
+if not config.enable_levelzero_runner:
+ config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_dpas.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_dpas.mlir
new file mode 100644
index 0000000..0bd3d3f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_dpas.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+ gpu.module @kernel {
+ // - Sets of `matrix_mad` intrinsics can differ based on device's *minimal* supported sub-group size.
+ // The *minimum supported* sub-group size should be used to call `matrix_mad` intrinsics.
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
+
+ gpu.func @block_dpas(%a: !llvm.ptr<1>, %b: !llvm.ptr<1>, %c: !llvm.ptr<1>) kernel {
+ %base_width_a = arith.constant 32 : i32
+ %base_height_a = arith.constant 8 : i32
+ %base_pitch_a = arith.constant 32 : i32
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+
+ %base_width_b = arith.constant 32 : i32
+ %base_height_b = arith.constant 16 : i32
+ %base_pitch_b = arith.constant 32 : i32
+ %loaded_b1 = xevm.blockload2d %b, %base_width_b, %base_height_b, %base_pitch_b, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
+ %loaded_b_casted = vector.bitcast %loaded_b1 : vector<16xi16> to vector<8xi32>
+
+ %base_width_c = arith.constant 64 : i32
+ %base_height_c = arith.constant 8 : i32
+ %base_pitch_c = arith.constant 64 : i32
+ %loaded_c = xevm.blockload2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+
+ %loaded_c_casted = vector.bitcast %loaded_c : vector<8xi32> to vector<8xf32>
+ %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted
+ {shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32>}
+ : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ %c_result_casted = vector.bitcast %c_result : vector<8xf32> to vector<8xi32>
+
+ xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ gpu.return
+ }
+ }
+
+ func.func @test(%a : memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+
+ %memref_a = gpu.alloc() : memref<8x16xf16>
+ gpu.memcpy %memref_a, %a : memref<8x16xf16>, memref<8x16xf16>
+ %a_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_a : memref<8x16xf16> -> index
+ %a_ptr_as_i64 = arith.index_cast %a_ptr_as_idx : index to i64
+ %a_ptr = llvm.inttoptr %a_ptr_as_i64 : i64 to !llvm.ptr
+ %a_ptr_casted = llvm.addrspacecast %a_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_b = gpu.alloc() : memref<16x16xf16>
+ gpu.memcpy %memref_b, %b : memref<16x16xf16>, memref<16x16xf16>
+ %b_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_b : memref<16x16xf16> -> index
+ %b_ptr_as_i64 = arith.index_cast %b_ptr_as_idx : index to i64
+ %b_ptr = llvm.inttoptr %b_ptr_as_i64 : i64 to !llvm.ptr
+ %b_ptr_casted = llvm.addrspacecast %b_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_c = gpu.alloc() : memref<8x16xf32>
+ gpu.memcpy %memref_c, %c : memref<8x16xf32>, memref<8x16xf32>
+ %c_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_c : memref<8x16xf32> -> index
+ %c_ptr_as_i64 = arith.index_cast %c_ptr_as_idx : index to i64
+ %c_ptr = llvm.inttoptr %c_ptr_as_i64 : i64 to !llvm.ptr
+ %c_ptr_casted = llvm.addrspacecast %c_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ gpu.launch_func @kernel::@block_dpas blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+ args(%a_ptr_casted : !llvm.ptr<1>, %b_ptr_casted : !llvm.ptr<1>, %c_ptr_casted : !llvm.ptr<1>)
+ gpu.dealloc %memref_a : memref<8x16xf16>
+ gpu.dealloc %memref_b : memref<16x16xf16>
+ %res = memref.alloc() : memref<8x16xf32>
+ gpu.memcpy %res, %memref_c : memref<8x16xf32>, memref<8x16xf32>
+ gpu.dealloc %memref_c : memref<8x16xf32>
+ return %res : memref<8x16xf32>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %A = memref.alloc() : memref<8x16xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ %row_idx = arith.index_cast %i : index to i32
+ %row = arith.sitofp %row_idx : i32 to f16
+ memref.store %row, %A[%i, %j] : memref<8x16xf16>
+ }
+ }
+ %B = memref.alloc() : memref<16x16xf16>
+ scf.for %i = %c0 to %c16 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ %col_idx = arith.index_cast %j : index to i32
+ %col = arith.sitofp %col_idx : i32 to f16
+ memref.store %col, %B[%i, %j] : memref<16x16xf16>
+ }
+ }
+
+ %C = memref.alloc() : memref<8x16xf32>
+ %c0_f16 = arith.constant 0.0 : f32
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c0_f16, %C[%i, %j] : memref<8x16xf32>
+ }
+ }
+
+ %C_res = call @test(%A, %B, %C) : (memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) -> memref<8x16xf32>
+ %C_cast = memref.cast %C_res : memref<8x16xf32> to memref<*xf32>
+ %A_cast = memref.cast %A : memref<8x16xf16> to memref<*xf16>
+ call @printMemrefF32(%C_cast) : (memref<*xf32>) -> ()
+
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]
+ // CHECK-NEXT: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480]
+ // CHECK-NEXT: [0, 48, 96, 144, 192, 240, 288, 336, 384, 432, 480, 528, 576, 624, 672, 720]
+ // CHECK-NEXT: [0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960]
+ // CHECK-NEXT: [0, 80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120, 1200]
+ // CHECK-NEXT: [0, 96, 192, 288, 384, 480, 576, 672, 768, 864, 960, 1056, 1152, 1248, 1344, 1440]
+ // CHECK-NEXT: [0, 112, 224, 336, 448, 560, 672, 784, 896, 1008, 1120, 1232, 1344, 1456, 1568, 1680]
+
+ memref.dealloc %A : memref<8x16xf16>
+ memref.dealloc %B : memref<16x16xf16>
+ memref.dealloc %C : memref<8x16xf32>
+ memref.dealloc %C_res : memref<8x16xf32>
+ return
+ }
+ func.func private @printMemrefF16(%ptr : memref<*xf16>) attributes { llvm.emit_c_interface }
+ func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }
+
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store.mlir
new file mode 100644
index 0000000..cea05b8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+
+ gpu.module @kernel {
+ // - `cl_intel_subgroups` block load/store intrinsics operate at the *maximum* sub-group size,
+ // regardless of the active sub-group size. Make sure `clGetKernelSubGroupInfo` meets your expectations.
+ // - The attribute `intel_reqd_sub_group_size` establishes the maximum sub-group size for a kernel.
+ //
+ // Note: launching 16 threads without explicit `intel_reqd_sub_group_size = 16` may still use
+ // the default sub-group size of 32.
+ //
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_required_subgroup_size.html
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html
+
+ gpu.func @block_load_store(%src: !llvm.ptr<1>, %dst: !llvm.ptr<1>) kernel {
+ %base_width = arith.constant 64 : i32 // bytewidth of the block
+ %base_height = arith.constant 8 : i32 // number of rows
+ %base_pitch = arith.constant 64 : i32 // bytewidth of the base row
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ // If `intel_reqd_sub_group_size = 16` is not set, the default (32) is used and this `blockload2d`
+ // would only load 4 elements into vector<8xi32>
+ %loaded = xevm.blockload2d %src, %base_width, %base_height, %base_pitch, %x, %y
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded_f32 = vector.bitcast %loaded : vector<8xi32> to vector<8xf32>
+ %c0 = arith.constant 0 : index
+ %thread_x = gpu.thread_id x
+ %thread_x_i64 = arith.index_cast %thread_x : index to i64
+ %thread_x_i32 = llvm.trunc %thread_x_i64 : i64 to i32
+ %thread_x_f32 = arith.sitofp %thread_x_i32 : i32 to f32
+ %loaded_f32_modified = vector.insert %thread_x_f32, %loaded_f32[%c0] : f32 into vector<8xf32>
+ %loaded_modified = vector.bitcast %loaded_f32_modified : vector<8xf32> to vector<8xi32>
+ xevm.blockstore2d %dst, %base_width, %base_height, %base_pitch, %x, %y, %loaded_modified
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ gpu.return
+ }
+ }
+
+ func.func @test(%src : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index // Multiple of the *maximum sub-group size* (see `intel_reqd_sub_group_size`)
+ %memref_src = gpu.alloc() : memref<8x16xf32>
+ gpu.memcpy %memref_src, %src : memref<8x16xf32>, memref<8x16xf32>
+ %src_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_src : memref<8x16xf32> -> index
+ %src_ptr_as_i64 = arith.index_cast %src_ptr_as_idx : index to i64
+ %src_ptr = llvm.inttoptr %src_ptr_as_i64 : i64 to !llvm.ptr
+ %src_ptr_casted = llvm.addrspacecast %src_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_dst = gpu.alloc() : memref<8x16xf32>
+ %dst_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_dst : memref<8x16xf32> -> index
+ %dst_ptr_as_i64 = arith.index_cast %dst_ptr_as_idx : index to i64
+ %dst_ptr = llvm.inttoptr %dst_ptr_as_i64 : i64 to !llvm.ptr
+ %dst_ptr_casted = llvm.addrspacecast %dst_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ gpu.launch_func @kernel::@block_load_store blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+ args(%src_ptr_casted : !llvm.ptr<1>, %dst_ptr_casted : !llvm.ptr<1>)
+ gpu.dealloc %memref_src : memref<8x16xf32>
+ %dst = memref.alloc() : memref<8x16xf32>
+ gpu.memcpy %dst, %memref_dst : memref<8x16xf32>, memref<8x16xf32>
+ gpu.dealloc %memref_dst : memref<8x16xf32>
+ return %dst : memref<8x16xf32>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %A = memref.alloc() : memref<8x16xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c11_f32 = arith.constant 11.11 : f32
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c11_f32, %A[%i, %j] : memref<8x16xf32>
+ }
+ }
+ %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32>
+ %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
+ %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32>
+ call @printMemrefF32(%A_cast) : (memref<*xf32>) -> ()
+ call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
+
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [11.11{{.*}}]
+ // CHECK-COUNT-96: 11.11
+ // CHECK-NEXT: [11.11{{.*}}]
+
+ // CHECK-NEXT: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+ // CHECK-COUNT-96: 11.11
+ // CHECK-NEXT: [11.11{{.*}}]
+
+ memref.dealloc %A : memref<8x16xf32>
+ memref.dealloc %B : memref<8x16xf32>
+ return
+ }
+ func.func private @printMemrefF32(%ptr : memref<*xf32>)
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_pack_register.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_pack_register.mlir
new file mode 100644
index 0000000..cb8ab1c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_pack_register.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+ gpu.module @kernel {
+ gpu.func @block_load_store(%src: !llvm.ptr<1>, %dst: !llvm.ptr<1>) kernel {
+ %base_width = arith.constant 32 : i32 // bytewidth of the block
+ %base_height_load = arith.constant 16 : i32 // number of rows
+ %base_pitch = arith.constant 32 : i32 // bytewidth of the base row
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+
+ // Consider the following two loads:
+ // Normal load:
+ %loaded = xevm.blockload2d %src, %base_width, %base_height_load, %base_pitch, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
+ %loaded_f16_flat = vector.bitcast %loaded : vector<16xi16> to vector<16xf16>
+ %loaded_f16 = vector.shape_cast %loaded_f16_flat : vector<16xf16> to vector<8x1x2xf16>
+
+ // Register packed load:
+ %loaded_packed = xevm.blockload2d %src, %base_width, %base_height_load, %base_pitch, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded_packed_f16_flat = vector.bitcast %loaded_packed : vector<8xi32> to vector<16xf16>
+ %loaded_packed_f16 = vector.shape_cast %loaded_packed_f16_flat : vector<16xf16> to vector<8x1x2xf16>
+ // Both can be represented the same way in code as vector<16xf16>.
+ // A normal load pads a value to a dword (e.g., 32-bit) when loaded to a register.
+ // Packed load "packs" multiple sub-dword values along the column (↓), allowing a single register
+ // to hold multiple values.
+ // In SIMT, a work-item reads values along the column (↓), hence a sequence of values loaded by packing
+ // to register is logically equivalent to the sequence of values loaded using a normal load.
+ // The load results of both methods can have the same logical representation, but are expected to
+ // differ in physical layout and register efficiency.
+
+ %thread_x = gpu.thread_id x
+ %thread_x_i64 = arith.index_cast %thread_x : index to i64
+ %thread_x_i32 = llvm.trunc %thread_x_i64 : i64 to i32
+ %thread_x_f16 = arith.sitofp %thread_x_i32 : i32 to f16
+ %loaded_f16_modified = vector.insert %thread_x_f16, %loaded_packed_f16 [0,0,1] : f16 into vector<8x1x2xf16> // Both loaded_packed_f16 and loaded_f16 can be used here
+ // We can only store [1,2,4,8]x[16] shapes for f16, so we have to do 2 stores
+ %loaded_f16_modified_slice_0 = vector.extract_strided_slice %loaded_f16_modified
+ {offsets = [0, 0, 0], sizes = [4, 1, 2], strides = [1, 1, 1]} : vector<8x1x2xf16> to vector<4x1x2xf16>
+ %loaded_f16_modified_slice_0_flat = vector.shape_cast %loaded_f16_modified_slice_0 : vector<4x1x2xf16> to vector<8xf16>
+ %base_height_store = arith.constant 8 : i32 // number of rows
+ %base_width_store = arith.constant 32 : i32 // bytewidth of the block
+ %base_pitch_store = arith.constant 32 : i32 // bytewidth of the base row
+ xevm.blockstore2d %dst, %base_width_store, %base_height_store, %base_pitch_store, %x, %y, %loaded_f16_modified_slice_0_flat
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xf16>)
+
+ %loaded_f16_modified_slice_1 = vector.extract_strided_slice %loaded_f16_modified
+ {offsets = [4, 0, 0], sizes = [4, 1, 2], strides = [1, 1, 1]} : vector<8x1x2xf16> to vector<4x1x2xf16>
+ %loaded_f16_modified_slice_1_flat = vector.shape_cast %loaded_f16_modified_slice_1 : vector<4x1x2xf16> to vector<8xf16>
+
+ %second_half_offset = arith.muli %base_pitch_store, %base_height_store : i32
+ %second_half_ptr = llvm.getelementptr %dst[%second_half_offset] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
+ xevm.blockstore2d %second_half_ptr, %base_width_store, %base_height_store, %base_pitch_store, %x, %y, %loaded_f16_modified_slice_1_flat
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xf16>)
+ gpu.return
+ }
+ }
+
+
+ func.func @test(%src : memref<16x16xf16>) -> memref<16x16xf16> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index // Multiple of the *maximum sub-group size* (see `intel_reqd_sub_group_size`)
+ %memref_src = gpu.alloc() : memref<16x16xf16>
+ gpu.memcpy %memref_src, %src : memref<16x16xf16>, memref<16x16xf16>
+ %src_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_src : memref<16x16xf16> -> index
+ %src_ptr_as_i64 = arith.index_cast %src_ptr_as_idx : index to i64
+ %src_ptr = llvm.inttoptr %src_ptr_as_i64 : i64 to !llvm.ptr
+ %src_ptr_casted = llvm.addrspacecast %src_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_dst = gpu.alloc() : memref<16x16xf16>
+ %dst_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_dst : memref<16x16xf16> -> index
+ %dst_ptr_as_i64 = arith.index_cast %dst_ptr_as_idx : index to i64
+ %dst_ptr = llvm.inttoptr %dst_ptr_as_i64 : i64 to !llvm.ptr
+ %dst_ptr_casted = llvm.addrspacecast %dst_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ gpu.launch_func @kernel::@block_load_store blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+ args(%src_ptr_casted : !llvm.ptr<1>, %dst_ptr_casted : !llvm.ptr<1>)
+ gpu.dealloc %memref_src : memref<16x16xf16>
+ %dst = memref.alloc() : memref<16x16xf16>
+ gpu.memcpy %dst, %memref_dst : memref<16x16xf16>, memref<16x16xf16>
+ gpu.dealloc %memref_dst : memref<16x16xf16>
+ return %dst : memref<16x16xf16>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %A = memref.alloc() : memref<16x16xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ %c11_f32 = arith.constant 11.1 : f16
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c11_f32, %A[%i, %j] : memref<16x16xf16>
+ }
+ }
+ %B = call @test(%A) : (memref<16x16xf16>) -> memref<16x16xf16>
+ %B_cast = memref.cast %B : memref<16x16xf16> to memref<*xf16>
+ %A_cast = memref.cast %A : memref<16x16xf16> to memref<*xf16>
+ call @printMemrefF16(%A_cast) : (memref<*xf16>) -> ()
+ call @printMemrefF16(%B_cast) : (memref<*xf16>) -> ()
+
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [11.1{{.*}}]
+ // CHECK-COUNT-224: 11.1
+ // CHECK-NEXT: [11.1{{.*}}]
+
+ // CHECK-NEXT: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [11.1{{.*}}]
+ // CHECK: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+ // CHECK-COUNT-208: 11.1
+ // CHECK-NEXT: [11.1{{.*}}]
+
+ memref.dealloc %A : memref<16x16xf16>
+ memref.dealloc %B : memref<16x16xf16>
+ return
+ }
+ func.func private @printMemrefF16(%ptr : memref<*xf16>) attributes { llvm.emit_c_interface }
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_transpose.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_transpose.mlir
new file mode 100644
index 0000000..1d164be
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_load_store_transpose.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+ gpu.module @kernel {
+ gpu.func @block_load_store(%src: !llvm.ptr<1>, %dst: !llvm.ptr<1>) kernel {
+ %base_width = arith.constant 32 : i32 // bytewidth of the block
+ %base_height = arith.constant 16 : i32 // number of rows
+ %base_pitch = arith.constant 32 : i32 // bytewidth of the base row
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ // Normally a work-item loads a vertical slice (↓), but with *transpose* a work-item
+ // loads a horizontal slice (→).
+ // The tile dimension we want to slice must be a multiple of the sub-group size:
+ // e.g., we want to slice rows (→), then we need SG_SIZE % tile_height == 0.
+ %loaded = xevm.blockload2d %src, %base_width, %base_height, %base_pitch, %x, %y
+ <{elem_size_in_bits=32 : i32, tile_width=8 : i32, tile_height=16 : i32, v_blocks=1 : i32,
+ transpose=true, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded_f32 = vector.bitcast %loaded : vector<8xi32> to vector<8xf32>
+
+ %c0 = arith.constant 0 : i32
+ %thread_x = gpu.thread_id x
+ %thread_x_i64 = arith.index_cast %thread_x : index to i64
+ %thread_x_i32 = llvm.trunc %thread_x_i64 : i64 to i32
+ %thread_x_f32 = arith.sitofp %thread_x_i32 : i32 to f32
+ %loaded_f32_modified = vector.insert %thread_x_f32, %loaded_f32[7] : f32 into vector<8xf32> // Use this to see where threadIds end up stored
+ %loaded_f32_modified_1 = vector.bitcast %loaded_f32_modified : vector<8xf32> to vector<8xi32>
+
+ %base_height_store = arith.constant 8 : i32 // number of rows
+ %base_width_store = arith.constant 64 : i32 // bytewidth of the block
+ %base_pitch_store = arith.constant 64 : i32 // bytewidth of the base row
+ // "Transposed" stores are not available, meaning a work-item can store its vector as a vertical slice (↓).
+ xevm.blockstore2d %dst, %base_width_store, %base_height_store, %base_pitch_store, %x, %y, %loaded
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ gpu.return
+ }
+ }
+
+
+ func.func @test(%src : memref<16x8xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index // Multiple of the *maximum sub-group size* (see `intel_reqd_sub_group_size`)
+ %memref_src = gpu.alloc() : memref<16x8xf32>
+ gpu.memcpy %memref_src, %src : memref<16x8xf32>, memref<16x8xf32>
+ %src_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_src : memref<16x8xf32> -> index
+ %src_ptr_as_i64 = arith.index_cast %src_ptr_as_idx : index to i64
+ %src_ptr = llvm.inttoptr %src_ptr_as_i64 : i64 to !llvm.ptr
+ %src_ptr_casted = llvm.addrspacecast %src_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_dst = gpu.alloc() : memref<8x16xf32>
+ %dst_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_dst : memref<8x16xf32> -> index
+ %dst_ptr_as_i64 = arith.index_cast %dst_ptr_as_idx : index to i64
+ %dst_ptr = llvm.inttoptr %dst_ptr_as_i64 : i64 to !llvm.ptr
+ %dst_ptr_casted = llvm.addrspacecast %dst_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ gpu.launch_func @kernel::@block_load_store blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+ args(%src_ptr_casted : !llvm.ptr<1>, %dst_ptr_casted : !llvm.ptr<1>)
+ gpu.dealloc %memref_src : memref<16x8xf32>
+ %dst = memref.alloc() : memref<8x16xf32>
+ gpu.memcpy %dst, %memref_dst : memref<8x16xf32>, memref<8x16xf32>
+ gpu.dealloc %memref_dst : memref<8x16xf32>
+ return %dst : memref<8x16xf32>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %A = memref.alloc() : memref<16x8xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c11_f32 = arith.constant 11.11 : f16
+ scf.for %i = %c0 to %c16 step %c1 {
+ scf.for %j = %c0 to %c8 step %c1 {
+ %c_10_f = arith.constant 10.0 : f32
+ %j_i64 = arith.index_cast %j : index to i64
+ %j_i32 = llvm.trunc %j_i64 : i64 to i32
+ %j_f32 = arith.sitofp %j_i32 : i32 to f32
+ %jj = arith.divf %j_f32, %c_10_f : f32
+
+ %i_i64 = arith.index_cast %i : index to i64
+ %i_i32 = llvm.trunc %i_i64 : i64 to i32
+ %i_f32 = arith.sitofp %i_i32 : i32 to f32
+ %ii = arith.addf %i_f32, %jj : f32
+ memref.store %ii, %A[%i, %j] : memref<16x8xf32>
+ }
+ }
+ %B = call @test(%A) : (memref<16x8xf32>) -> memref<8x16xf32>
+ %A_cast = memref.cast %A : memref<16x8xf32> to memref<*xf32>
+ %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
+ call @printMemrefF32(%A_cast) : (memref<*xf32>) -> ()
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
+ // CHECK-NEXT: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
+ // CHECK-NEXT: [2, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7],
+ // CHECK-NEXT: [3, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7],
+ // CHECK-NEXT: [4, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7],
+ // CHECK-NEXT: [5, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7],
+ // CHECK-NEXT: [6, 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7],
+ // CHECK-NEXT: [7, 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7],
+ // CHECK-NEXT: [8, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7],
+ // CHECK-NEXT: [9, 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7],
+ // CHECK-NEXT: [10, 10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7],
+ // CHECK-NEXT: [11, 11.1, 11.2, 11.3, 11.4, 11.5, 11.6, 11.7],
+ // CHECK-NEXT: [12, 12.1, 12.2, 12.3, 12.4, 12.5, 12.6, 12.7],
+ // CHECK-NEXT: [13, 13.1, 13.2, 13.3, 13.4, 13.5, 13.6, 13.7],
+ // CHECK-NEXT: [14, 14.1, 14.2, 14.3, 14.4, 14.5, 14.6, 14.7],
+ // CHECK-NEXT: [15, 15.1, 15.2, 15.3, 15.4, 15.5, 15.6, 15.7]
+
+ call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ // CHECK-NEXT: [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1],
+ // CHECK-NEXT: [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2],
+ // CHECK-NEXT: [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3],
+ // CHECK-NEXT: [0.4, 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4],
+ // CHECK-NEXT: [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5],
+ // CHECK-NEXT: [0.6, 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6],
+ // CHECK-NEXT: [0.7, 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7]
+
+ memref.dealloc %A : memref<16x8xf32>
+ memref.dealloc %B : memref<8x16xf32>
+ return
+ }
+ func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_store_cst.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_store_cst.mlir
new file mode 100644
index 0000000..c5f4cd5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_store_cst.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+
+ gpu.module @kernel {
+ gpu.func @store_constant(%ptr: !llvm.ptr<1>) kernel {
+ %const_val = arith.constant 42.0 : f32
+ %thread_x = gpu.lane_id
+ %thread_x_i64 = arith.index_cast %thread_x : index to i64
+ %ptr_next_1 = llvm.getelementptr %ptr[%thread_x_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i32
+ llvm.store %const_val, %ptr_next_1 : f32, !llvm.ptr<1>
+ gpu.return
+ }
+ }
+ func.func @test(%src : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %memref_0 = gpu.alloc() : memref<8x16xf32>
+ gpu.memcpy %memref_0, %src : memref<8x16xf32>, memref<8x16xf32>
+ %0 = memref.extract_aligned_pointer_as_index %memref_0 : memref<8x16xf32> -> index
+ %1 = arith.index_cast %0 : index to i64
+ %2 = llvm.inttoptr %1 : i64 to !llvm.ptr
+ %src_casted = llvm.addrspacecast %2 : !llvm.ptr to !llvm.ptr<1>
+ gpu.launch_func @kernel::@store_constant blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+ args(%src_casted : !llvm.ptr<1>)
+ %dst = memref.alloc() : memref<8x16xf32>
+ gpu.memcpy %dst, %memref_0 : memref<8x16xf32>, memref<8x16xf32>
+ gpu.dealloc %memref_0 : memref<8x16xf32>
+
+ return %dst : memref<8x16xf32>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %A = memref.alloc() : memref<8x16xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c11_f32 = arith.constant 11.11 : f32
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c11_f32, %A[%i, %j] : memref<8x16xf32>
+ }
+ }
+ %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32>
+ %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
+ %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32>
+ call @printMemrefF32(%A_cast) : (memref<*xf32>) -> ()
+ call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
+
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [11.11{{.*}}]
+ // CHECK-COUNT-96: 11.11
+ // CHECK-NEXT: [11.11{{.*}}]
+
+ // CHECK-NEXT: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-NEXT: [42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42]
+ // CHECK-COUNT-96: 11.11
+ // CHECK-NEXT: [11.11{{.*}}]
+
+ memref.dealloc %A : memref<8x16xf32>
+ memref.dealloc %B : memref<8x16xf32>
+ return
+ }
+ func.func private @printMemrefF32(%ptr : memref<*xf32>)
+}
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
index a5653f39..37564de 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
@@ -57,7 +57,6 @@
func.func private @printMemrefF32(memref<*xf32>)
-memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
func.func @main() {
@@ -148,12 +147,11 @@ func.func @main() {
%c57344 = arith.constant 57344 : index
%c40960 = arith.constant 40960 : index
- %tidx = gpu.thread_id x
- %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
- %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
+ %tidx = gpu.thread_id x
%dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>>
+ %rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>>
+
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
%barrier = nvgpu.mbarrier.create -> !barrierType
%cnd = arith.cmpi eq, %tidx, %c0 : index
@@ -202,11 +200,11 @@ func.func @main() {
// TMA wait
%phase_c0 = arith.constant 0 : i1
nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
- %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
- %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
+ %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, #gpu.address_space<workgroup>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+ %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, #gpu.address_space<workgroup>> to memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
// Descriptor WGMMA
- %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
- %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
+ %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
+ %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
// Perform WGMMA 128x128x64
%md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
@@ -271,7 +269,7 @@ func.func @main() {
vector.print str "Correct Results :"
vector.print %correctCount : i32
vector.print str "Incorrect Results :"
- vector.print %errorCount : i32
+ vector.print %errorCount : i32
return
}
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
index 197351f..db7754c 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
@@ -57,7 +57,6 @@
func.func private @printMemrefF32(memref<*xf32>)
-memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
func.func @main() {
@@ -149,11 +148,10 @@ func.func @main() {
%c40960 = arith.constant 40960 : index
%tidx = gpu.thread_id x
- %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
- %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
%dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>>
+ %rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>>
+
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
%barrier = nvgpu.mbarrier.create -> !barrierType
@@ -210,11 +208,11 @@ func.func @main() {
// TMA wait
%phase_c0 = arith.constant 0 : i1
nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
- %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
- %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
+ %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, #gpu.address_space<workgroup>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+ %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, #gpu.address_space<workgroup>> to memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
// Descriptor WGMMA
- %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
- %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
+ %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
+ %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
// Perform WGMMA 128x128x64
%md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index 9385d35..03ac38e 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -13,6 +13,7 @@
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_6 %s
// RUN: mlir-opt %s -verify-each=false '-test-options-super-pass-pipeline=super-list={{enum=zero list=1 string=foo},{enum=one list=2 string="bar"},{enum=two list=3 string={baz}}}' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
+// RUN: mlir-opt %s -verify-each=false -test-options-super-set-ab-pipeline='foo=true bar=false' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_11 %s
// This test checks that lists-of-nested-options like 'option1={...},{....}' can be parsed
@@ -106,3 +107,12 @@
// CHECK_10-NEXT: test-options-pass{enum=zero string= string-list={,}}
// CHECK_10-NEXT: )
// CHECK_10-NEXT: )
+
+// CHECK_11: builtin.module(
+// CHECK_11-NEXT: func.func(
+// CHECK_11-NEXT: test-options-pass-a
+// CHECK_11-NEXT: )
+// CHECK_11-NEXT: func.func(
+// CHECK_11-NEXT: test-options-pass-b
+// CHECK_11-NEXT: )
+// CHECK_11-NEXT: )
diff --git a/mlir/test/Target/Cpp/class.mlir b/mlir/test/Target/Cpp/class.mlir
new file mode 100644
index 0000000..32c6699
--- /dev/null
+++ b/mlir/test/Target/Cpp/class.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+emitc.class @modelClass {
+ emitc.field @fieldName0 : !emitc.array<1xf32>
+ emitc.field @fieldName1 : !emitc.array<1xf32>
+ emitc.func @execute() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = get_field @fieldName0 : !emitc.array<1xf32>
+ %2 = get_field @fieldName1 : !emitc.array<1xf32>
+ %3 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+}
+
+// CHECK-LABEL: class modelClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT: float fieldName0[1];
+// CHECK-NEXT: float fieldName1[1];
+// CHECK-NEXT: void execute() {
+// CHECK-NEXT: size_t v1 = 0;
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-NEXT: };
+
+emitc.class final @finalClass {
+ emitc.field @fieldName0 : !emitc.array<1xf32>
+ emitc.field @fieldName1 : !emitc.array<1xf32>
+ emitc.func @execute() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = get_field @fieldName0 : !emitc.array<1xf32>
+ %2 = get_field @fieldName1 : !emitc.array<1xf32>
+ %3 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+}
+
+// CHECK-LABEL: class finalClass final {
+// CHECK-NEXT: public:
+// CHECK-NEXT: float fieldName0[1];
+// CHECK-NEXT: float fieldName1[1];
+// CHECK-NEXT: void execute() {
+// CHECK-NEXT: size_t v1 = 0;
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-NEXT: };
+
+emitc.class @mainClass {
+ emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
+ emitc.func @get_fieldName0() {
+ %0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
+ return
+ }
+}
+
+// CHECK-LABEL: class mainClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT: float fieldName0[2] = {0.0e+00f, 0.0e+00f};
+// CHECK-NEXT: void get_fieldName0() {
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-NEXT: };
+
+emitc.class @reflectionClass {
+ emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> = #emitc.opaque<"{ { \22another_feature\22, \22fieldName0\22 } }">
+ emitc.func @get_reflectionMap() {
+ %0 = emitc.get_field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>">
+ return
+ }
+}
+
+// CHECK-LABEL: class reflectionClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT: const std::map<std::string, std::string> reflectionMap = { { "another_feature", "fieldName0" } };
+// CHECK-NEXT: void get_reflectionMap() {
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-NEXT: };
+
diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir
index d3656f8..2a5ff1a 100644
--- a/mlir/test/Target/Cpp/const.mlir
+++ b/mlir/test/Target/Cpp/const.mlir
@@ -16,6 +16,8 @@ func.func @emitc_constant() {
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
+ %c11 = "emitc.constant"(){value = dense<[0, 1]> : !emitc.array<2xindex>} : () -> !emitc.array<2xindex>
+ %c12 = "emitc.constant"(){value = dense<[0.0, 1.0]> : !emitc.array<2xf32>} : () -> !emitc.array<2xf32>
return
}
// CPP-DEFAULT: void emitc_constant() {
@@ -33,6 +35,8 @@ func.func @emitc_constant() {
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
+// CPP-DEFAULT-NEXT: size_t [[V11:[^ ]*]][2] = {0, 1};
+// CPP-DEFAULT-NEXT: float [[V12:[^ ]*]][2] = {0.0e+00f, 1.000000000e+00f};
// CPP-DECLTOP: void emitc_constant() {
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -49,6 +53,8 @@ func.func @emitc_constant() {
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
+// CPP-DECLTOP-NEXT: size_t [[V11:[^ ]*]][2];
+// CPP-DECLTOP-NEXT: float [[V12:[^ ]*]][2];
// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX;
// CPP-DECLTOP-NEXT: [[V1]] = 42;
// CPP-DECLTOP-NEXT: [[V2]] = -1;
@@ -63,3 +69,5 @@ func.func @emitc_constant() {
// CPP-DECLTOP-NEXT: [[V8]] = {0};
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
+// CPP-DECLTOP-NEXT: [[V11]] = {0, 1};
+// CPP-DECLTOP-NEXT: [[V12]] = {0.0e+00f, 1.000000000e+00f};
diff --git a/mlir/test/Target/Cpp/control_flow.mlir b/mlir/test/Target/Cpp/control_flow.mlir
index 101b30c..ce9a0ee 100644
--- a/mlir/test/Target/Cpp/control_flow.mlir
+++ b/mlir/test/Target/Cpp/control_flow.mlir
@@ -70,7 +70,7 @@ func.func @block_labels1() {
// CPP-DECLTOP-NEXT: }
emitc.func @expression_inlining(%0 : i32, %1 : i32) {
- %2 = expression : i1 {
+ %2 = expression %0, %1 : (i32, i32) -> i1 {
%3 = cmp lt, %0, %1 : (i32, i32) -> i1
yield %3 : i1
}
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 9316d7b..433a67cc 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -30,7 +30,7 @@
func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
%p0 = emitc.literal "M_PI" : i32
- %e = emitc.expression : i1 {
+ %e = emitc.expression %arg0, %arg1, %arg2, %arg3, %p0 : (i32, i32, i32, i32, i32) -> i1 {
%a = emitc.mul %arg0, %p0 : (i32, i32) -> i32
%b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32)
%c = emitc.sub %b, %arg3 : (i32, i32) -> i32
@@ -61,7 +61,7 @@ func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
// CPP-DECLTOP-NEXT:}
func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
- %e = emitc.expression noinline : i32 {
+ %e = emitc.expression %arg0, %arg1, %arg2 noinline : (i32, i32, i32) -> i32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.mul %a, %arg2 : (i32, i32) -> i32
emitc.yield %b : i32
@@ -78,7 +78,7 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
- %e = emitc.expression : f32 {
+ %e = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> f32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.mul %a, %arg2 : (i32, i32) -> i32
%d = emitc.cast %b : i32 to f32
@@ -95,7 +95,7 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) ->
// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
- %e = emitc.expression : i32 {
+ %e = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
@@ -145,32 +145,32 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
// CPP-DECLTOP-NEXT: }
func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
%c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
- %e0 = emitc.expression : i32 {
+ %e0 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
- %e1 = emitc.expression : i32 {
+ %e1 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
- %e2 = emitc.expression : i32 {
+ %e2 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
- %e3 = emitc.expression : i32 {
+ %e3 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
- %e4 = emitc.expression : i32 {
+ %e4 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
- %e5 = emitc.expression : i32 {
+ %e5 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
@@ -217,7 +217,7 @@ func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32
// CPP-DECLTOP-NEXT: }
func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
- %e = emitc.expression : i1 {
+ %e = emitc.expression %arg0, %arg1, %arg2, %arg3 : (i32, i32, i32, i32) -> i1 {
%a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32)
%c = emitc.sub %b, %arg3 : (i32, i32) -> i32
@@ -269,16 +269,16 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
// CPP-DECLTOP-NEXT: }
func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
- %e1 = emitc.expression : i32 {
+ %e1 = emitc.expression %arg2, %arg3 : (i32, i32) -> i32 {
%a = emitc.rem %arg2, %arg3 : (i32, i32) -> i32
emitc.yield %a : i32
}
- %e2 = emitc.expression : i32 {
+ %e2 = emitc.expression %arg0, %arg1, %e1 : (i32, i32, i32) -> i32 {
%a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.call_opaque "bar" (%e1, %a) : (i32, i32) -> (i32)
emitc.yield %b : i32
}
- %e3 = emitc.expression : i1 {
+ %e3 = emitc.expression %arg1, %e2, %arg3 : (i32, i32, i32) -> i1 {
%c = emitc.sub %e2, %arg3 : (i32, i32) -> i32
%d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1
emitc.yield %d : i1
@@ -295,6 +295,25 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
return %v_load : i32
}
+// CPP-DEFAULT: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) {
+// CPP-DEFAULT-NEXT: int32_t [[VAL_3:v[0-9]+]] = *([[VAL_2]] - [[VAL_1]]);
+// CPP-DEFAULT-NEXT: return [[VAL_3]];
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) {
+// CPP-DECLTOP-NEXT: int32_t [[VAL_3:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_3]] = *([[VAL_2]] - [[VAL_1]]);
+// CPP-DECLTOP-NEXT: return [[VAL_3]];
+// CPP-DECLTOP-NEXT: }
+func.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 {
+ %c = emitc.expression %arg1, %arg2 : (i32, !emitc.ptr<i32>) -> i32 {
+ %e = emitc.sub %arg2, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
+ %d = emitc.apply "*"(%e) : (!emitc.ptr<i32>) -> i32
+ emitc.yield %d : i32
+ }
+ return %c : i32
+}
+
// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42;
// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
@@ -308,7 +327,7 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
%a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32>
- %c = emitc.expression : i1 {
+ %c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 {
%d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
%e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
%f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
@@ -334,7 +353,7 @@ func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.p
func.func @expression_with_subscript_user(%arg0: !emitc.ptr<!emitc.opaque<"void">>) -> i32 {
%c0 = "emitc.constant"() {value = 0 : i64} : () -> i64
- %0 = emitc.expression : !emitc.ptr<i32> {
+ %0 = emitc.expression %arg0 : (!emitc.ptr<!emitc.opaque<"void">>) -> !emitc.ptr<i32> {
%0 = emitc.cast %arg0 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
emitc.yield %0 : !emitc.ptr<i32>
}
@@ -362,7 +381,7 @@ func.func @expression_with_load(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>)
%c0 = "emitc.constant"() {value = 0 : i64} : () -> i64
%0 = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32>
%ptr = emitc.subscript %arg2[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32>
- %result = emitc.expression : i1 {
+ %result = emitc.expression %arg0, %arg1, %0, %ptr : (i32, i32, !emitc.lvalue<i32>, !emitc.lvalue<i32>) -> i1 {
%a = emitc.load %0 : !emitc.lvalue<i32>
%b = emitc.add %a, %arg1 : (i32, i32) -> i32
%c = emitc.load %ptr : !emitc.lvalue<i32>
@@ -388,7 +407,7 @@ func.func @expression_with_load(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>)
func.func @expression_with_load_and_call(%arg0: !emitc.ptr<i32>) -> i1 {
%c0 = "emitc.constant"() {value = 0 : i64} : () -> i64
%ptr = emitc.subscript %arg0[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32>
- %result = emitc.expression : i1 {
+ %result = emitc.expression %ptr : (!emitc.lvalue<i32>) -> i1 {
%a = emitc.load %ptr : !emitc.lvalue<i32>
%b = emitc.load %ptr : !emitc.lvalue<i32>
%c = emitc.load %ptr : !emitc.lvalue<i32>
@@ -399,3 +418,24 @@ func.func @expression_with_load_and_call(%arg0: !emitc.ptr<i32>) -> i1 {
}
return %result : i1
}
+
+
+// CPP-DEFAULT: void expression_with_call_opaque_with_args_array(int32_t [[v1:v.+]], int32_t [[v2:v.+]]) {
+// CPP-DEFAULT-NEXT: bool [[v3:v.+]] = f(([[v1]] < [[v2]]));
+// CPP-DEFAULT-NEXT: return;
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: void expression_with_call_opaque_with_args_array(int32_t [[v1:v.+]], int32_t [[v2:v.+]]) {
+// CPP-DECLTOP-NEXT: bool [[v3:v.+]];
+// CPP-DECLTOP-NEXT: [[v3]] = f(([[v1]] < [[v2]]));
+// CPP-DECLTOP-NEXT: return;
+// CPP-DECLTOP-NEXT: }
+
+emitc.func @expression_with_call_opaque_with_args_array(%0 : i32, %1 : i32) {
+ %2 = expression %0, %1 : (i32, i32) -> i1 {
+ %3 = cmp lt, %0, %1 : (i32, i32) -> i1
+ %4 = emitc.call_opaque "f"(%3) {"args" = [0: index]} : (i1) -> i1
+ yield %4 : i1
+ }
+ return
+}
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index 7cd3d5d..73375d5 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -2,15 +2,15 @@
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck --match-full-lines %s -check-prefix=CPP-DECLTOP
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
- %lb = emitc.expression : index {
+ %lb = emitc.expression %arg0, %arg1 : (index, index) -> index {
%a = emitc.add %arg0, %arg1 : (index, index) -> index
emitc.yield %a : index
}
- %ub = emitc.expression : index {
+ %ub = emitc.expression %arg1, %arg2 : (index, index) -> index {
%a = emitc.mul %arg1, %arg2 : (index, index) -> index
emitc.yield %a : index
}
- %step = emitc.expression : index {
+ %step = emitc.expression %arg0, %arg2 : (index, index) -> index {
%a = emitc.div %arg0, %arg2 : (index, index) -> index
emitc.yield %a : index
}
diff --git a/mlir/test/Target/Cpp/member.mlir b/mlir/test/Target/Cpp/member.mlir
index 20589fe..6e03952 100644
--- a/mlir/test/Target/Cpp/member.mlir
+++ b/mlir/test/Target/Cpp/member.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
-func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32) {
+func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) {
%var0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<!emitc.opaque<"mystruct">>
emitc.assign %arg0 : !emitc.opaque<"mystruct"> to %var0 : !emitc.lvalue<!emitc.opaque<"mystruct">>
@@ -12,19 +12,31 @@ func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32) {
%3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
emitc.assign %2 : i32 to %3 : !emitc.lvalue<i32>
+ %4 = "emitc.member" (%var0) {member = "c"} : (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.array<2xi32>
+ %5 = emitc.subscript %4[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue<i32>
+ %6 = emitc.load %5 : <i32>
+ emitc.assign %6 : i32 to %3 : !emitc.lvalue<i32>
+
+ %7 = "emitc.member" (%var0) {member = "d"} : (!emitc.lvalue<!emitc.opaque<"mystruct">>) -> !emitc.array<2xi32>
+ %8 = emitc.subscript %7[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue<i32>
+ emitc.assign %arg1 : i32 to %8 : !emitc.lvalue<i32>
+
return
}
-// CPP-DEFAULT: void member(mystruct [[V0:[^ ]*]], int32_t [[V1:[^ ]*]]) {
+// CPP-DEFAULT: void member(mystruct [[V0:[^ ]*]], int32_t [[V1:[^ ]*]], size_t [[Index:[^ ]*]]) {
// CPP-DEFAULT-NEXT: mystruct [[V2:[^ ]*]];
// CPP-DEFAULT-NEXT: [[V2]] = [[V0]];
// CPP-DEFAULT-NEXT: [[V2]].a = [[V1]];
// CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = [[V2]].b;
// CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]];
// CPP-DEFAULT-NEXT: [[V4]] = [[V3]];
+// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = [[V2]].c[[[Index]]];
+// CPP-DEFAULT-NEXT: [[V4]] = [[V5]];
+// CPP-DEFAULT-NEXT: [[V2]].d[[[Index]]] = [[V1]];
-func.func @member_of_pointer(%arg0: !emitc.ptr<!emitc.opaque<"mystruct">>, %arg1: i32) {
+func.func @member_of_pointer(%arg0: !emitc.ptr<!emitc.opaque<"mystruct">>, %arg1: i32, %arg2: index) {
%var0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>
emitc.assign %arg0 : !emitc.ptr<!emitc.opaque<"mystruct">> to %var0 : !emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>
@@ -36,14 +48,25 @@ func.func @member_of_pointer(%arg0: !emitc.ptr<!emitc.opaque<"mystruct">>, %arg1
%3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
emitc.assign %2 : i32 to %3 : !emitc.lvalue<i32>
+ %4 = "emitc.member_of_ptr" (%var0) {member = "c"} : (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>) -> !emitc.array<2xi32>
+ %5 = emitc.subscript %4[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue<i32>
+ %6 = emitc.load %5 : <i32>
+ emitc.assign %6 : i32 to %3 : !emitc.lvalue<i32>
+
+ %7 = "emitc.member_of_ptr" (%var0) {member = "d"} : (!emitc.lvalue<!emitc.ptr<!emitc.opaque<"mystruct">>>) -> !emitc.array<2xi32>
+ %8 = emitc.subscript %7[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue<i32>
+ emitc.assign %arg1 : i32 to %8 : !emitc.lvalue<i32>
+
return
}
-// CPP-DEFAULT: void member_of_pointer(mystruct* [[V0:[^ ]*]], int32_t [[V1:[^ ]*]]) {
+// CPP-DEFAULT: void member_of_pointer(mystruct* [[V0:[^ ]*]], int32_t [[V1:[^ ]*]], size_t [[Index:[^ ]*]]) {
// CPP-DEFAULT-NEXT: mystruct* [[V2:[^ ]*]];
// CPP-DEFAULT-NEXT: [[V2]] = [[V0]];
// CPP-DEFAULT-NEXT: [[V2]]->a = [[V1]];
// CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = [[V2]]->b;
// CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]];
// CPP-DEFAULT-NEXT: [[V4]] = [[V3]];
-
+// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = [[V2]]->c[[[Index]]];
+// CPP-DEFAULT-NEXT: [[V4]] = [[V5]];
+// CPP-DEFAULT-NEXT: [[V2]]->d[[[Index]]] = [[V1]];
diff --git a/mlir/test/Target/Cpp/switch.mlir b/mlir/test/Target/Cpp/switch.mlir
index 4e20c1f..87e4cb8 100644
--- a/mlir/test/Target/Cpp/switch.mlir
+++ b/mlir/test/Target/Cpp/switch.mlir
@@ -907,7 +907,7 @@ func.func @emitc_switch_ui64() {
func.func @emitc_switch_expression() {
%x = "emitc.constant"(){value = 42 : i64} : () -> i64
- %0 = emitc.expression : i64 {
+ %0 = emitc.expression %x : (i64) -> i64 {
%a = emitc.unary_minus %x : (i64) -> i64
emitc.yield %a : i64
}
diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
index 0b13645..cc3d799 100644
--- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll
+++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
@@ -339,18 +339,6 @@ declare void @func_attr_no_nans_fp_math_false() "no-nans-fp-math"="false"
; // -----
-; CHECK-LABEL: @func_attr_approx_func_fp_math_true
-; CHECK-SAME: attributes {approx_func_fp_math = true}
-declare void @func_attr_approx_func_fp_math_true() "approx-func-fp-math"="true"
-
-; // -----
-
-; CHECK-LABEL: @func_attr_approx_func_fp_math_false
-; CHECK-SAME: attributes {approx_func_fp_math = false}
-declare void @func_attr_approx_func_fp_math_false() "approx-func-fp-math"="false"
-
-; // -----
-
; CHECK-LABEL: @func_attr_no_signed_zeros_fp_math_true
; CHECK-SAME: attributes {no_signed_zeros_fp_math = true}
declare void @func_attr_no_signed_zeros_fp_math_true() "no-signed-zeros-fp-math"="true"
@@ -426,3 +414,8 @@ declare void @nounwind_attribute() nounwind
; CHECK-LABEL: @willreturn_attribute
; CHECK-SAME: attributes {will_return}
declare void @willreturn_attribute() willreturn
+
+// -----
+
+; expected-warning @unknown {{'preallocated' attribute is invalid on current operation, skipping it}}
+declare void @test() preallocated(i32)
diff --git a/mlir/test/Target/LLVMIR/Import/global-variables.ll b/mlir/test/Target/LLVMIR/Import/global-variables.ll
index b8bbdba..102162a 100644
--- a/mlir/test/Target/LLVMIR/Import/global-variables.ll
+++ b/mlir/test/Target/LLVMIR/Import/global-variables.ll
@@ -186,19 +186,16 @@
; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<2 x f32>
@array_constant = internal constant [2 x float] [float 1., float 2.]
-; CHECK: llvm.mlir.global internal constant @nested_array_constant
-; CHECK-SAME-LITERAL: (dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>)
-; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<2 x array<2 x i32>>
+; CHECK{LITERAL}: llvm.mlir.global internal constant @nested_array_constant(dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>)
+; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<2 x array<2 x i32>>
@nested_array_constant = internal constant [2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]]
-; CHECK: llvm.mlir.global internal constant @nested_array_constant3
-; CHECK-SAME-LITERAL: (dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi32>)
-; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>>
+; CHECK{LITERAL}: llvm.mlir.global internal constant @nested_array_constant3(dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi32>)
+; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>>
@nested_array_constant3 = internal constant [1 x [2 x [2 x i32]]] [[2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]]]
-; CHECK: llvm.mlir.global internal constant @nested_array_vector
-; CHECK-SAME-LITERAL: (dense<[[[1, 2], [3, 4]]]> : vector<1x2x2xi32>)
-; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>>
+; CHECK{LITERAL}: llvm.mlir.global internal constant @nested_array_vector(dense<[[[1, 2], [3, 4]]]> : vector<1x2x2xi32>)
+; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>>
@nested_array_vector = internal constant [1 x [2 x <2 x i32>]] [[2 x <2 x i32>] [<2 x i32> <i32 1, i32 2>, <2 x i32> <i32 3, i32 4>]]
; CHECK: llvm.mlir.global internal constant @vector_constant_zero
@@ -221,9 +218,8 @@
; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>>
@nested_array_vector_zero = internal constant [1 x [2 x <2 x i32>]] zeroinitializer
-; CHECK: llvm.mlir.global internal constant @nested_bool_array_constant
-; CHECK-SAME-LITERAL: (dense<[[true, false]]> : tensor<1x2xi1>)
-; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x i1>>
+; CHECK{LITERAL}: llvm.mlir.global internal constant @nested_bool_array_constant(dense<[[true, false]]> : tensor<1x2xi1>)
+; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x i1>>
@nested_bool_array_constant = internal constant [1 x [2 x i1]] [[2 x i1] [i1 1, i1 0]]
; CHECK: llvm.mlir.global internal constant @quad_float_constant
@@ -358,3 +354,18 @@ declare void @"mlir.llvm.nameless_global_2"()
; CHECK: llvm.mlir.global private unnamed_addr constant @mlir.llvm.nameless_global_0("0\00")
@0 = private unnamed_addr constant [2 x i8] c"0\00"
+
+; // -----
+
+; CHECK-LABEL: llvm.mlir.global external @target_specific_attrs_only
+; CHECK-SAME: target_specific_attrs = {{\[\[}}"memory", "0"], ["int-attr", "4"], "no-enum-attr", ["string-attr", "string"]]}
+@target_specific_attrs_only = external global double #0
+attributes #0 = { readnone "int-attr"="4" "no-enum-attr" "string-attr"="string" }
+
+; // -----
+
+; CHECK-LABEL: llvm.mlir.global external @target_specific_attrs_combined
+; CHECK-SAME: alignment = 4 : i64, section = "mysection",
+; CHECK-SAME: target_specific_attrs = ["norecurse", ["bss-section", "my_bss.1"]]}
+@target_specific_attrs_combined = global i32 2, align 4, section "mysection" #0
+attributes #0 = { norecurse "bss-section"="my_bss.1" }
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic-prefer-unregistered.ll b/mlir/test/Target/LLVMIR/Import/intrinsic-prefer-unregistered.ll
index 797a75c..18c9319 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic-prefer-unregistered.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic-prefer-unregistered.ll
@@ -3,9 +3,9 @@
; CHECK-LABEL: llvm.func @lifetime
define void @lifetime() {
%a = alloca [16 x i8]
- ; CHECK: llvm.call_intrinsic "llvm.lifetime.start.p0"({{.*}}, %[[ptr:.*]]) : (i64, !llvm.ptr {llvm.nonnull}) -> ()
- call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %a)
- ; CHECK: llvm.call_intrinsic "llvm.lifetime.end.p0"({{.*}}, %[[ptr]]) : (i64, !llvm.ptr {llvm.nonnull}) -> ()
- call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %a)
+ ; CHECK: llvm.call_intrinsic "llvm.lifetime.start.p0"(%[[ptr:.*]]) : (!llvm.ptr {llvm.nonnull}) -> ()
+ call void @llvm.lifetime.start.p0(ptr nonnull %a)
+ ; CHECK: llvm.call_intrinsic "llvm.lifetime.end.p0"(%[[ptr]]) : (!llvm.ptr {llvm.nonnull}) -> ()
+ call void @llvm.lifetime.end.p0(ptr nonnull %a)
ret void
}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index a419d75..07d2212 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -545,6 +545,15 @@ define void @masked_expand_compress_intrinsics(ptr %0, <7 x i1> %1, <7 x float>
ret void
}
+; CHECK-LABEL: llvm.func @masked_expand_compress_intrinsics_with_alignment
+define void @masked_expand_compress_intrinsics_with_alignment(ptr %0, <7 x i1> %1, <7 x float> %2) {
+ ; CHECK: %[[val1:.+]] = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
+ %4 = call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %0, <7 x i1> %1, <7 x float> %2)
+ ; CHECK: "llvm.intr.masked.compressstore"(%[[val1]], %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
+ call void @llvm.masked.compressstore.v7f32(<7 x float> %4, ptr align 8 %0, <7 x i1> %1)
+ ret void
+}
+
; CHECK-LABEL: llvm.func @annotate_intrinsics
define void @annotate_intrinsics(ptr %var, ptr %ptr, i16 %int, ptr %annotation, ptr %fileName, i32 %line, ptr %args) {
; CHECK: "llvm.intr.var.annotation"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr) -> ()
@@ -876,10 +885,10 @@ define void @stack_restore(ptr %0, ptr addrspace(1) %1) {
; CHECK-LABEL: llvm.func @lifetime
define void @lifetime() {
%a = alloca [16 x i8]
- ; CHECK: llvm.intr.lifetime.start 16, %{{.*}} : !llvm.ptr
- call void @llvm.lifetime.start.p0(i64 16, ptr %a)
- ; CHECK: llvm.intr.lifetime.end 32, %{{.*}} : !llvm.ptr
- call void @llvm.lifetime.end.p0(i64 32, ptr %a)
+ ; CHECK: llvm.intr.lifetime.start %{{.*}} : !llvm.ptr
+ call void @llvm.lifetime.start.p0(ptr %a)
+ ; CHECK: llvm.intr.lifetime.end %{{.*}} : !llvm.ptr
+ call void @llvm.lifetime.end.p0(ptr %a)
ret void
}
@@ -1353,8 +1362,8 @@ declare <8 x i64> @llvm.vp.fptoui.v8i64.v8f64(<8 x double>, <8 x i1>, i32)
declare <8 x i64> @llvm.vp.fptosi.v8i64.v8f64(<8 x double>, <8 x i1>, i32)
declare <8 x i64> @llvm.vp.ptrtoint.v8i64.v8p0(<8 x ptr>, <8 x i1>, i32)
declare <8 x ptr> @llvm.vp.inttoptr.v8p0.v8i64(<8 x i64>, <8 x i1>, i32)
-declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
-declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
+declare void @llvm.lifetime.start.p0(ptr nocapture)
+declare void @llvm.lifetime.end.p0(ptr nocapture)
declare ptr @llvm.invariant.start.p0(i64 immarg, ptr nocapture)
declare void @llvm.invariant.end.p0(ptr, i64 immarg, ptr nocapture)
declare ptr @llvm.launder.invariant.group.p0(ptr nocapture)
diff --git a/mlir/test/Target/LLVMIR/fp-math-function-attributes.mlir b/mlir/test/Target/LLVMIR/fp-math-function-attributes.mlir
index 673cbd8..7b11fdc 100644
--- a/mlir/test/Target/LLVMIR/fp-math-function-attributes.mlir
+++ b/mlir/test/Target/LLVMIR/fp-math-function-attributes.mlir
@@ -54,24 +54,6 @@ llvm.func @no_nans_fp_math_func_false() attributes {no_nans_fp_math = false} {
// -----
-// CHECK-LABEL: define void @approx_func_fp_math_func_true()
-// CHECK-SAME: #[[ATTRS:[0-9]+]]
-llvm.func @approx_func_fp_math_func_true() attributes {approx_func_fp_math = true} {
- llvm.return
-}
-// CHECK: attributes #[[ATTRS]] = { "approx-func-fp-math"="true" }
-
-// -----
-//
-// CHECK-LABEL: define void @approx_func_fp_math_func_false()
-// CHECK-SAME: #[[ATTRS:[0-9]+]]
-llvm.func @approx_func_fp_math_func_false() attributes {approx_func_fp_math = false} {
- llvm.return
-}
-// CHECK: attributes #[[ATTRS]] = { "approx-func-fp-math"="false" }
-
-// -----
-
// CHECK-LABEL: define void @no_signed_zeros_fp_math_func_true()
// CHECK-SAME: #[[ATTRS:[0-9]+]]
llvm.func @no_signed_zeros_fp_math_func_true() attributes {no_signed_zeros_fp_math = true} {
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index eb3510c..c99dde3 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -577,6 +577,17 @@ llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr, %mask: vector<7xi1
llvm.return
}
+// CHECK-LABEL: @masked_expand_compress_intrinsics_with_alignment
+llvm.func @masked_expand_compress_intrinsics_with_alignment(%ptr: !llvm.ptr, %mask: vector<7xi1>, %passthru: vector<7xf32>) {
+ // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+ %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru) {arg_attrs = [{llvm.align = 8 : i32}, {}, {}]}
+ : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> (vector<7xf32>)
+ // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, ptr align 8 %{{.*}}, <7 x i1> %{{.*}})
+ "llvm.intr.masked.compressstore"(%0, %ptr, %mask) {arg_attrs = [{}, {llvm.align = 8 : i32}, {}]}
+ : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
+ llvm.return
+}
+
// CHECK-LABEL: @annotate_intrinsics
llvm.func @annotate_intrinsics(%var: !llvm.ptr, %int: i16, %ptr: !llvm.ptr, %annotation: !llvm.ptr, %fileName: !llvm.ptr, %line: i32, %attr: !llvm.ptr) {
// CHECK: call void @llvm.var.annotation.p0.p0(ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, ptr %{{.*}})
@@ -1104,9 +1115,9 @@ llvm.func @lifetime() {
%c = llvm.mlir.constant(16 : i64) : i64
%a = llvm.alloca %c x i8 : (i64) -> !llvm.ptr
// CHECK: call void @llvm.lifetime.start
- llvm.intr.lifetime.start 16, %a : !llvm.ptr
+ llvm.intr.lifetime.start %a : !llvm.ptr
// CHECK: call void @llvm.lifetime.end
- llvm.intr.lifetime.end 16, %a : !llvm.ptr
+ llvm.intr.lifetime.end %a : !llvm.ptr
llvm.return
}
@@ -1418,8 +1429,8 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) {
// CHECK-DAG: declare <2 x i32> @llvm.vector.extract.v2i32.v8i32(<8 x i32>, i64 immarg)
// CHECK-DAG: declare { <2 x double>, <2 x double> } @llvm.vector.deinterleave2.v4f64(<4 x double>)
// CHECK-DAG: declare { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32>)
-// CHECK-DAG: declare void @llvm.lifetime.start.p0(i64 immarg, ptr captures(none))
-// CHECK-DAG: declare void @llvm.lifetime.end.p0(i64 immarg, ptr captures(none))
+// CHECK-DAG: declare void @llvm.lifetime.start.p0(ptr captures(none))
+// CHECK-DAG: declare void @llvm.lifetime.end.p0(ptr captures(none))
// CHECK-DAG: declare ptr @llvm.invariant.start.p0(i64 immarg, ptr captures(none))
// CHECK-DAG: declare void @llvm.invariant.end.p0(ptr, i64 immarg, ptr captures(none))
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index b09ceee..c263afe 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -29,6 +29,26 @@ llvm.func @passthrough_wrong_type() attributes {
// -----
+// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
+llvm.mlir.global external @target_specific_attrs_unexpected_value() {target_specific_attrs = [["readonly", "42"]]} : f64
+
+// -----
+
+// expected-error @below{{LLVM attribute 'alignstack' expects a value}}
+llvm.mlir.global external @target_specific_attrs_expected_value() {target_specific_attrs = ["alignstack"]} : f64
+
+// -----
+
+// expected-error @below{{expected 'target_specific_attrs' to contain string or array attributes}}
+llvm.mlir.global external @target_specific_attrs_wrong_type() {target_specific_attrs = [42]} : f64
+
+// -----
+
+// expected-error @below{{expected arrays within 'target_specific_attrs' to contain two strings}}
+llvm.mlir.global external @target_specific_attrs_wrong_type() {target_specific_attrs = [[ 42, 42 ]]} : f64
+
+// -----
+
llvm.func @unary_float_intr_wrong_type(%arg0 : i32) -> i32 {
// expected-error @below{{op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type}}
%0 = "llvm.intr.exp"(%arg0) : (i32) -> i32
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fc1993b..69814f2 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -3028,3 +3028,15 @@ llvm.mlir.global internal constant @test_array_attr_struct_with_struct() : !llvm
%0 = llvm.mlir.constant([#llvm.zero, [2 : i32, 1.0 : f32], #llvm.undef]) : !llvm.array<3 x struct<(i32, f32)>>
llvm.return %0 : !llvm.array<3 x struct<(i32, f32)>>
}
+
+// -----
+
+// CHECK: @target_specific_attrs_only = external global double #[[ATTRS:[0-9]+]]
+// CHECK: attributes #[[ATTRS]] = { memory(none) "int-attr"="4" "no-enum-attr" "string-attr"="string" }
+llvm.mlir.global external @target_specific_attrs_only() {target_specific_attrs = [["memory", "0"], ["int-attr", "4"], "no-enum-attr", ["string-attr", "string"]]} : f64
+
+// -----
+
+// CHECK: @target_specific_attrs_combined = global i32 2, section "mysection", align 4 #[[ATTRS:[0-9]+]]
+// CHECK: attributes #[[ATTRS]] = { norecurse "bss-section"="my_bss.1" }
+llvm.mlir.global external @target_specific_attrs_combined(2 : i32) {alignment = 4 : i64, section = "mysection", target_specific_attrs = ["norecurse", ["bss-section", "my_bss.1"]]} : i32
diff --git a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
index f38b752..5f8e8d0 100644
--- a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
@@ -32,8 +32,8 @@ llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
- nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
- nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+ nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
+ nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
llvm.return
}
@@ -45,3 +45,17 @@ llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
llvm.return
}
+
+llvm.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
+ // CHECK-LABEL: define void @prefetch_tensormap(ptr %0, ptr addrspace(4) %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p0(ptr %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p4(ptr addrspace(4) %1)
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(101)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p101(ptr addrspace(101) %3)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
+ nvvm.prefetch tensormap, %const_ptr: !llvm.ptr<4>
+ nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
index bfd9526..536b52b 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
@@ -1,70 +1,123 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
-// CHECK-LABEL: @tma_bulk_prefetch
llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_bulk_prefetch(ptr addrspace(1) %0, i32 %1, i64 %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 %2, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_1d
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_1d(ptr %0, i32 %1, i64 %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 %2, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_2d
llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+ // CHECK-LABEL: define void @tma_prefetch_2d(ptr %0, i32 %1, i32 %2, i64 %3) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 %3, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] {mode = #nvvm.tma_load_mode<tile>} : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_3d
-llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+ // CHECK-LABEL: define void @tma_prefetch_3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_4d
llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_5d
llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %9, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ llvm.return
+}
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_prefetch_gather4_2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
new file mode 100644
index 0000000..23e47bd
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+ // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+ // expected-error @below {{im2col offsets expected 3 (provided 2)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_3d_im2col_w(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16) {
+ // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_4d_im2col_w_128(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16) {
+ // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_3d(%tma_desc : !llvm.ptr, %d0 : i32) {
+ // expected-error @below {{Gather4 mode expects 5 coordinates}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %off0 : i16, %ch : i64) {
+ // expected-error @below {{im2col offsets expected 0 (provided 1)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_store.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_store.mlir
new file mode 100644
index 0000000..b77927f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_store.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @tma_store_1d(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_store_1d(ptr %0, ptr addrspace(3) %1, i32 %2, i64 %3) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.1d(ptr addrspace(3) %1, ptr %0, i32 %2, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.1d(ptr addrspace(3) %1, ptr %0, i32 %2, i64 %3, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0] l2_cache_hint=%ch : !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
+
+llvm.func @tma_store_2d(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_store_2d(ptr %0, ptr addrspace(3) %1, i32 %2, i32 %3, i64 %4) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.2d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.2d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i64 %4, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1] l2_cache_hint=%ch : !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
+
+llvm.func @tma_store_3d(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_store_3d(ptr %0, ptr addrspace(3) %1, i32 %2, i32 %3, i32 %4, i64 %5) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.3d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.3d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i64 %5, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.im2col.3d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.im2col.3d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i64 %5, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2] l2_cache_hint=%ch : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2] {mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2] l2_cache_hint=%ch {mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @tma_store_4d(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_store_4d(ptr %0, ptr addrspace(3) %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.4d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.4d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.im2col.4d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.im2col.4d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3] l2_cache_hint=%ch : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3] {mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3] l2_cache_hint=%ch {mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @tma_store_5d(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_store_5d(ptr %0, ptr addrspace(3) %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 %7) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.5d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.5d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.im2col.5d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.im2col.5d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 %7, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] l2_cache_hint=%ch : !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] {mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] l2_cache_hint=%ch {mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @tma_store_scatter(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_store_scatter(ptr %0, ptr addrspace(3) %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 %7) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.scatter4.2d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.scatter4.2d(ptr addrspace(3) %1, ptr %0, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i64 %7, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] {mode = #nvvm.tma_store_mode<tile_scatter4>}: !llvm.ptr, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] l2_cache_hint=%ch {mode = #nvvm.tma_store_mode<tile_scatter4>}: !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_store_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_store_invalid.mlir
new file mode 100644
index 0000000..9d9dc8e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_store_invalid.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// -----
+
+llvm.func @tma_store_1d_im2col(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %ch : i64) {
+ // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0] {mode = #nvvm.tma_store_mode<im2col>} : !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_store_0d(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>) {
+ // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[] : !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_store_scatter(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %ch : i64) {
+ // expected-error @below {{Scatter4 mode expects 5 coordinates}}
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0,%crd1,%crd2,%crd3] l2_cache_hint=%ch {mode = #nvvm.tma_store_mode<tile_scatter4>}: !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_store_asm_ch(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %ch : i64, %p : i1) {
+ // expected-error @below {{Inline-ptx lowering unsupported with L2 cache-hint.}}
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0] l2_cache_hint=%ch, predicate=%p : !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_store_asm_im2col(%tma_desc: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %ch : i64, %p : i1) {
+ // expected-error @below {{Inline-ptx lowering supported only for Tile mode.}}
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tma_desc, %src, box[%crd0, %crd1, %crd2], predicate=%p {mode = #nvvm.tma_store_mode<im2col>} : !llvm.ptr, !llvm.ptr<3>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce_invalid.mlir
new file mode 100644
index 0000000..2fcf00f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce_invalid.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// -----
+
+llvm.func @tma_reduce_0d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %ch : i64) {
+ // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+ nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[] {redKind = #nvvm.tma_redux_kind<add>}: !llvm.ptr, !llvm.ptr<3>
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
+ // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+ nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_store_reduce_scatter(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // expected-error @below {{Scatter mode unsupported for CpAsyncBulkTensorReduceOp}}
+ nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind<add>, mode = #nvvm.tma_store_mode<tile_scatter4>} : !llvm.ptr, !llvm.ptr<3>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 85478cc..b35a6db 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -1,5 +1,24 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+llvm.func @pmevent_no_id() {
+ // expected-error @below {{either `id` or `mask` must be set}}
+ nvvm.pmevent
+}
+
+// -----
+
+llvm.func @pmevent_bigger15() {
+ // expected-error @below {{`id` must be between 0 and 15}}
+ nvvm.pmevent id = 141
+}
+
+// -----
+
+llvm.func @pmevent_many_ids() {
+ // expected-error @below {{`id` and `mask` cannot be set at the same time}}
+ nvvm.pmevent id = 1 mask = 1
+}
+
// -----
llvm.func @kernel_func(%numberOfThreads : i32) {
@@ -37,73 +56,49 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
// -----
-llvm.func @nvvm_fence_proxy_acquire(%addr : !llvm.ptr, %size : i32) {
- // expected-error @below {{'nvvm.fence.proxy.acquire' op uni-directional proxies only support generic for from_proxy attribute}}
- nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %size from_proxy=#nvvm.proxy_kind<tensormap> to_proxy=#nvvm.proxy_kind<generic>
+// expected-error @below {{'"nvvm.blocksareclusters"' attribute must be used along with 'nvvm.reqntid' and 'nvvm.cluster_dim'}}
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
+ nvvm.cluster_dim = array<i32: 3, 5, 7>} {
llvm.return
}
// -----
-llvm.func @nvvm_fence_proxy_release() {
- // expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support generic for from_proxy attribute}}
- nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<tensormap> to_proxy=#nvvm.proxy_kind<generic>
+// expected-error @below {{'"nvvm.blocksareclusters"' attribute must be used along with 'nvvm.reqntid' and 'nvvm.cluster_dim'}}
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
+ nvvm.reqntid = array<i32: 1, 23, 32>} {
llvm.return
}
// -----
llvm.func @nvvm_fence_proxy_acquire(%addr : !llvm.ptr, %size : i32) {
- // expected-error @below {{'nvvm.fence.proxy.acquire' op uni-directional proxies only support tensormap for to_proxy attribute}}
- nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %size from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
+ // expected-error @below {{'nvvm.fence.proxy.acquire' op uni-directional proxies only support generic for from_proxy attribute}}
+ nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %size from_proxy=#nvvm.proxy_kind<tensormap> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
}
// -----
llvm.func @nvvm_fence_proxy_release() {
- // expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
- nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
- llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
- // expected-error @below {{expects coordinates between 1 to 5 dimension}}
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
- llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
- // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
- llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
- // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
+ // expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support generic for from_proxy attribute}}
+ nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<tensormap> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
}
// -----
-llvm.func @tma_reduce_0d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %ch : i64) {
- // expected-error @below {{expects coordinates between 1 to 5 dimension}}
- nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[] {redKind = #nvvm.tma_redux_kind<add>}: !llvm.ptr, !llvm.ptr<3>
+llvm.func @nvvm_fence_proxy_acquire(%addr : !llvm.ptr, %size : i32) {
+ // expected-error @below {{'nvvm.fence.proxy.acquire' op uni-directional proxies only support tensormap for to_proxy attribute}}
+ nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %size from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
}
// -----
-llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
- // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
- nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
+llvm.func @nvvm_fence_proxy_release() {
+ // expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
+ nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
}
@@ -253,7 +248,7 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
- nvvm.prefetch level = L1, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
+ nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
llvm.return
}
@@ -261,7 +256,7 @@ llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
// expected-error @below {{cache eviction priority requires a global pointer}}
- nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
+ nvvm.prefetch level = L2, evict_priority = evict_last, %local_ptr : !llvm.ptr<5>
llvm.return
}
@@ -269,7 +264,7 @@ llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm
llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
// expected-error @below {{cache eviction priority requires a global pointer}}
- nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_normal : !llvm.ptr<5>
+ nvvm.prefetch level = L2, evict_priority = evict_normal, %local_ptr : !llvm.ptr<5>
llvm.return
}
@@ -277,7 +272,7 @@ llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !ll
llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
- nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_first : !llvm.ptr<1>
+ nvvm.prefetch level = L2, evict_priority = evict_first, %global_ptr : !llvm.ptr<1>
llvm.return
}
@@ -285,7 +280,7 @@ llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>)
llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
- nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_unchanged : !llvm.ptr<1>
+ nvvm.prefetch level = L2, evict_priority = evict_unchanged, %global_ptr : !llvm.ptr<1>
llvm.return
}
@@ -293,7 +288,7 @@ llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<
llvm.func @nvvm_prefetch_L2_with_invalid_no_allocate(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
- nvvm.prefetch level = L2, %global_ptr, evict_priority = no_allocate : !llvm.ptr<1>
+ nvvm.prefetch level = L2, evict_priority = no_allocate, %global_ptr : !llvm.ptr<1>
llvm.return
}
@@ -315,6 +310,62 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
// -----
+llvm.func @nvvm_prefetch_both_tensormap_and_cache_level(%gen_ptr: !llvm.ptr) {
+ // expected-error @below {{cannot specify both tensormap and cache level}}
+ nvvm.prefetch level = L1, tensormap, %gen_ptr : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_tensormap_invalid_addr_space(%global_ptr: !llvm.ptr<1>) {
+ // expected-error @below {{prefetch tensormap requires a generic or constant pointer}}
+ nvvm.prefetch tensormap, %global_ptr : !llvm.ptr<1>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_tensormap_with_evict_priority(%gen_ptr: !llvm.ptr) {
+ // expected-error @below {{prefetch tensormap does not support eviction priority}}
+ nvvm.prefetch tensormap, evict_priority = evict_last, %gen_ptr : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_tensormap_in_param_space_non_generic(%const_ptr: !llvm.ptr<4>) {
+ // expected-error @below {{in_param_space can only be specified for a generic pointer}}
+ nvvm.prefetch tensormap in_param_space, %const_ptr : !llvm.ptr<4>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_cache_level_invalid_addr_space(%const_ptr: !llvm.ptr<4>) {
+ // expected-error @below {{prefetch to cache level requires a generic, global, or local pointer}}
+ nvvm.prefetch level = L1, %const_ptr : !llvm.ptr<4>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_predicate_without_tensormap(%gen_ptr: !llvm.ptr, %pred: i1) {
+ // expected-error @below {{predicate supported only on prefetch tensormap}}
+ nvvm.prefetch level = L1, %gen_ptr, predicate = %pred : !llvm.ptr, i1
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_no_level_or_tensormap(%gen_ptr: !llvm.ptr) {
+ // expected-error @below {{requires specification of either cache level or tensormap}}
+ nvvm.prefetch %gen_ptr : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32
@@ -351,3 +402,136 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
llvm.return
}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
+ nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}}
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @nanosleep() {
+ // expected-error@+1 {{integer constant out of range for attribute}}
+ nvvm.nanosleep 100000000000000
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 5c2cfa4..62aeb07 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -64,92 +64,94 @@ llvm.func @nvvm_special_regs() -> i32 {
%30 = nvvm.read.ptx.sreg.clock64 : i64
// CHECK: call i64 @llvm.nvvm.read.ptx.sreg.globaltimer
%31 = nvvm.read.ptx.sreg.globaltimer : i64
- // CHECK: %32 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
- %32 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.globaltimer.lo()
+ %32 = nvvm.read.ptx.sreg.globaltimer.lo : i32
+ // CHECK: %33 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+ %33 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.warpid
- %33 = nvvm.read.ptx.sreg.warpid : i32
+ %34 = nvvm.read.ptx.sreg.warpid : i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nwarpid
- %34 = nvvm.read.ptx.sreg.nwarpid : i32
+ %35 = nvvm.read.ptx.sreg.nwarpid : i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.smid
- %35 = nvvm.read.ptx.sreg.smid : i32
+ %36 = nvvm.read.ptx.sreg.smid : i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nsmid
- %36 = nvvm.read.ptx.sreg.nsmid : i32
+ %37 = nvvm.read.ptx.sreg.nsmid : i32
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.gridid
- %37 = nvvm.read.ptx.sreg.gridid : i32
+ %38 = nvvm.read.ptx.sreg.gridid : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg0
- %38 = nvvm.read.ptx.sreg.envreg0 : i32
+ %39 = nvvm.read.ptx.sreg.envreg0 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg1
- %39 = nvvm.read.ptx.sreg.envreg1 : i32
+ %40 = nvvm.read.ptx.sreg.envreg1 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg2
- %40 = nvvm.read.ptx.sreg.envreg2 : i32
+ %41 = nvvm.read.ptx.sreg.envreg2 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg3
- %41 = nvvm.read.ptx.sreg.envreg3 : i32
+ %42 = nvvm.read.ptx.sreg.envreg3 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg4
- %42 = nvvm.read.ptx.sreg.envreg4 : i32
+ %43 = nvvm.read.ptx.sreg.envreg4 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg5
- %43 = nvvm.read.ptx.sreg.envreg5 : i32
+ %44 = nvvm.read.ptx.sreg.envreg5 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg6
- %44 = nvvm.read.ptx.sreg.envreg6 : i32
+ %45 = nvvm.read.ptx.sreg.envreg6 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg7
- %45 = nvvm.read.ptx.sreg.envreg7 : i32
+ %46 = nvvm.read.ptx.sreg.envreg7 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg8
- %46 = nvvm.read.ptx.sreg.envreg8 : i32
+ %47 = nvvm.read.ptx.sreg.envreg8 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg9
- %47 = nvvm.read.ptx.sreg.envreg9 : i32
+ %48 = nvvm.read.ptx.sreg.envreg9 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg10
- %48 = nvvm.read.ptx.sreg.envreg10 : i32
+ %49 = nvvm.read.ptx.sreg.envreg10 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg11
- %49 = nvvm.read.ptx.sreg.envreg11 : i32
+ %50 = nvvm.read.ptx.sreg.envreg11 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg12
- %50 = nvvm.read.ptx.sreg.envreg12 : i32
+ %51 = nvvm.read.ptx.sreg.envreg12 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg13
- %51 = nvvm.read.ptx.sreg.envreg13 : i32
+ %52 = nvvm.read.ptx.sreg.envreg13 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg14
- %52 = nvvm.read.ptx.sreg.envreg14 : i32
+ %53 = nvvm.read.ptx.sreg.envreg14 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg15
- %53 = nvvm.read.ptx.sreg.envreg15 : i32
+ %54 = nvvm.read.ptx.sreg.envreg15 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg16
- %54 = nvvm.read.ptx.sreg.envreg16 : i32
+ %55 = nvvm.read.ptx.sreg.envreg16 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg17
- %55 = nvvm.read.ptx.sreg.envreg17 : i32
+ %56 = nvvm.read.ptx.sreg.envreg17 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg18
- %56 = nvvm.read.ptx.sreg.envreg18 : i32
+ %57 = nvvm.read.ptx.sreg.envreg18 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg19
- %57 = nvvm.read.ptx.sreg.envreg19 : i32
+ %58 = nvvm.read.ptx.sreg.envreg19 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg20
- %58 = nvvm.read.ptx.sreg.envreg20 : i32
+ %59 = nvvm.read.ptx.sreg.envreg20 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg21
- %59 = nvvm.read.ptx.sreg.envreg21 : i32
+ %60 = nvvm.read.ptx.sreg.envreg21 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg22
- %60 = nvvm.read.ptx.sreg.envreg22 : i32
+ %61 = nvvm.read.ptx.sreg.envreg22 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg23
- %61 = nvvm.read.ptx.sreg.envreg23 : i32
+ %62 = nvvm.read.ptx.sreg.envreg23 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg24
- %62 = nvvm.read.ptx.sreg.envreg24 : i32
+ %63 = nvvm.read.ptx.sreg.envreg24 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg25
- %63 = nvvm.read.ptx.sreg.envreg25 : i32
+ %64 = nvvm.read.ptx.sreg.envreg25 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg26
- %64 = nvvm.read.ptx.sreg.envreg26 : i32
+ %65 = nvvm.read.ptx.sreg.envreg26 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg27
- %65 = nvvm.read.ptx.sreg.envreg27 : i32
+ %66 = nvvm.read.ptx.sreg.envreg27 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg28
- %66 = nvvm.read.ptx.sreg.envreg28 : i32
+ %67 = nvvm.read.ptx.sreg.envreg28 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg29
- %67 = nvvm.read.ptx.sreg.envreg29 : i32
+ %68 = nvvm.read.ptx.sreg.envreg29 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg30
- %68 = nvvm.read.ptx.sreg.envreg30 : i32
+ %69 = nvvm.read.ptx.sreg.envreg30 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg31
- %69 = nvvm.read.ptx.sreg.envreg31 : i32
+ %70 = nvvm.read.ptx.sreg.envreg31 : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.eq
- %70 = nvvm.read.ptx.sreg.lanemask.eq : i32
+ %71 = nvvm.read.ptx.sreg.lanemask.eq : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.le
- %71 = nvvm.read.ptx.sreg.lanemask.le : i32
+ %72 = nvvm.read.ptx.sreg.lanemask.le : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.lt
- %72 = nvvm.read.ptx.sreg.lanemask.lt : i32
+ %73 = nvvm.read.ptx.sreg.lanemask.lt : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.ge
- %73 = nvvm.read.ptx.sreg.lanemask.ge : i32
+ %74 = nvvm.read.ptx.sreg.lanemask.ge : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt
- %74 = nvvm.read.ptx.sreg.lanemask.gt : i32
+ %75 = nvvm.read.ptx.sreg.lanemask.gt : i32
llvm.return %1 : i32
}
@@ -559,17 +561,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
- %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+ %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
+ %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
+ %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,eltType =#nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}
@@ -662,21 +694,24 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2
// CHECK: define ptx_kernel void @kernel_func() #[[ATTR0:[0-9]+]]
// CHECK: attributes #[[ATTR0]] = { "nvvm.maxnreg"="32" "nvvm.maxntid"="1,23,32" "nvvm.minctasm"="16" }
+// -----
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
+ nvvm.reqntid = array<i32: 1, 23, 32>,
+ nvvm.cluster_dim = array<i32: 3, 5, 7>} {
+ llvm.return
+}
+
+// CHECK: define ptx_kernel void @kernel_func() #[[ATTR0:[0-9]+]]
+// CHECK: attributes #[[ATTR0]] = { "nvvm.blocksareclusters" "nvvm.cluster_dim"="3,5,7" "nvvm.reqntid"="1,23,32" }
// -----
-// CHECK: define ptx_kernel void @kernel_func
-// CHECK: !nvvm.annotations =
-// CHECK: !{{.*}} = !{ptr @kernel_func, !"grid_constant", ![[ID:[[:alnum:]]+]]}
-// CHECK: ![[ID]] = !{i32 1}
+// CHECK: define ptx_kernel void @kernel_func(ptr byval(i32) "nvvm.grid_constant" %0)
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
// -----
-// CHECK: define ptx_kernel void @kernel_func
-// CHECK: !nvvm.annotations =
-// CHECK: !{{.*}} = !{ptr @kernel_func, !"grid_constant", ![[ID:[[:alnum:]]+]]}
-// CHECK: ![[ID]] = !{i32 1, i32 3}
+// CHECK: define ptx_kernel void @kernel_func(ptr byval(i32) "nvvm.grid_constant" %0, float %1, ptr byval(float) "nvvm.grid_constant" %2)
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
@@ -766,7 +801,7 @@ llvm.func @nvvm_wgmma_wait_group_aligned() {
// CHECK-LABEL: @nvvm_griddepcontrol_wait
llvm.func @nvvm_griddepcontrol_wait() {
// CHECK: call void @llvm.nvvm.griddepcontrol.wait()
- nvvm.griddepcontrol.wait
+ nvvm.griddepcontrol wait
llvm.return
}
@@ -774,7 +809,7 @@ llvm.func @nvvm_griddepcontrol_wait() {
// CHECK-LABEL: @nvvm_griddepcontrol_launch_dependents
llvm.func @nvvm_griddepcontrol_launch_dependents() {
// CHECK: call void @llvm.nvvm.griddepcontrol.launch.dependents()
- nvvm.griddepcontrol.launch.dependents
+ nvvm.griddepcontrol launch_dependents
llvm.return
}
@@ -783,8 +818,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
- // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
- %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+ // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
llvm.return
}
@@ -918,3 +953,23 @@ llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32
%7 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @nvvm_pmevent
+llvm.func @nvvm_pmevent() {
+ // CHECK: call void @llvm.nvvm.pm.event.mask(i16 15000)
+ nvvm.pmevent mask = 15000
+ // CHECK: call void @llvm.nvvm.pm.event.mask(i16 4)
+ nvvm.pmevent mask = 4
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @nanosleep
+llvm.func @nanosleep() {
+ // CHECK: call void @llvm.nvvm.nanosleep(i32 4000)
+ nvvm.nanosleep 4000
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir
new file mode 100644
index 0000000..1bc9760
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-openmp-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s
+
+// This file contains MLIR test cases for omp.target_allocmem and omp.target_freemem
+
+// CHECK-LABEL: test_alloc_free_i64
+// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 8, i32 0)
+// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64
+// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr
+// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0)
+// CHECK: ret void
+llvm.func @test_alloc_free_i64() -> () {
+ %device = llvm.mlir.constant(0 : i32) : i32
+ %1 = omp.target_allocmem %device : i32, i64
+ omp.target_freemem %device, %1 : i32, i64
+ llvm.return
+}
+
+// CHECK-LABEL: test_alloc_free_vector_1d_f32
+// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 64, i32 0)
+// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64
+// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr
+// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0)
+// CHECK: ret void
+llvm.func @test_alloc_free_vector_1d_f32() -> () {
+ %device = llvm.mlir.constant(0 : i32) : i32
+ %1 = omp.target_allocmem %device : i32, vector<16xf32>
+ omp.target_freemem %device, %1 : i32, i64
+ llvm.return
+}
+
+// CHECK-LABEL: test_alloc_free_vector_2d_f32
+// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 1024, i32 0)
+// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64
+// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr
+// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0)
+// CHECK: ret void
+llvm.func @test_alloc_free_vector_2d_f32() -> () {
+ %device = llvm.mlir.constant(0 : i32) : i32
+ %1 = omp.target_allocmem %device : i32, vector<16x16xf32>
+ omp.target_freemem %device, %1 : i32, i64
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/omptarget-atomic-capture-control-options.mlir b/mlir/test/Target/LLVMIR/omptarget-atomic-capture-control-options.mlir
new file mode 100644
index 0000000..3553907
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-atomic-capture-control-options.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK: atomicrmw add ptr %loadgep_, i32 1 monotonic, align 4, !amdgpu.no.remote.memory !{{.*}}
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr = dense<64> : vector<4xi64>, !llvm.ptr<1> = dense<64> : vector<4xi64>, !llvm.ptr<2> = dense<32> : vector<4xi64>, !llvm.ptr<3> = dense<32> : vector<4xi64>, !llvm.ptr<4> = dense<64> : vector<4xi64>, !llvm.ptr<5> = dense<32> : vector<4xi64>, !llvm.ptr<6> = dense<32> : vector<4xi64>, !llvm.ptr<7> = dense<[160, 256, 256, 32]> : vector<4xi64>, !llvm.ptr<8> = dense<[128, 128, 128, 48]> : vector<4xi64>, !llvm.ptr<9> = dense<[192, 256, 256, 32]> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.legal_int_widths" = array<i32: 32, 64>, "dlti.stack_alignment" = 32 : i64, "dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, fir.atomic_fine_grained_memory, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", fir.target_cpu = "generic-hsa", llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.flags = #omp.flags<openmp_device_version = 31>, omp.is_gpu = true, omp.is_target_device = true, omp.requires = #omp<clause_requires none>, omp.target_triples = [], omp.version = #omp.version<version = 31>} {
+ llvm.func @_QQmain() attributes {fir.bindc_name = "TEST", omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, target_cpu = "generic-hsa"} {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "threads"} : (i64) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ %3 = llvm.mlir.constant(1 : i64) : i64
+ %4 = llvm.alloca %3 x i32 {bindc_name = "capture"} : (i64) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %6 = llvm.mlir.constant(1 : i64) : i64
+ %7 = llvm.alloca %6 x i32 {bindc_name = "a"} : (i64) -> !llvm.ptr<5>
+ %8 = llvm.addrspacecast %7 : !llvm.ptr<5> to !llvm.ptr
+ %9 = llvm.mlir.constant(0 : i32) : i32
+ %10 = llvm.mlir.constant(128 : i32) : i32
+ %11 = llvm.mlir.constant(1 : i64) : i64
+ %12 = llvm.mlir.constant(1 : i64) : i64
+ %13 = llvm.mlir.constant(1 : i64) : i64
+ llvm.store %10, %2 : i32, !llvm.ptr
+ llvm.store %9, %8 : i32, !llvm.ptr
+ %14 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "threads"}
+ %15 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "capture"}
+ %16 = omp.map.info var_ptr(%8 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "a"}
+ omp.target map_entries(%14 -> %arg0, %15 -> %arg1, %16 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
+ %17 = llvm.mlir.constant(1 : i32) : i32
+ %18 = llvm.load %arg0 : !llvm.ptr -> i32
+ omp.parallel num_threads(%18 : i32) {
+ omp.atomic.capture {
+ omp.atomic.read %arg1 = %arg2 : !llvm.ptr, !llvm.ptr, i32
+ omp.atomic.update %arg2 : !llvm.ptr {
+ ^bb0(%arg3: i32):
+ %19 = llvm.add %arg3, %17 : i32
+ omp.yield(%19 : i32)
+ } {atomic_control = #omp.atomic_control<fine_grained_memory = true>}
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
diff --git a/mlir/test/Target/LLVMIR/omptarget-atomic-update-control-options.mlir b/mlir/test/Target/LLVMIR/omptarget-atomic-update-control-options.mlir
new file mode 100644
index 0000000..3b0005b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-atomic-update-control-options.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK: atomicrmw add ptr %loadgep_, i32 1 monotonic, align 4, !amdgpu.ignore.denormal.mode !{{.*}}, !amdgpu.no.fine.grained.memory !{{.*}}, !amdgpu.no.remote.memory !{{.*}}
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr = dense<64> : vector<4xi64>, !llvm.ptr<1> = dense<64> : vector<4xi64>, !llvm.ptr<2> = dense<32> : vector<4xi64>, !llvm.ptr<3> = dense<32> : vector<4xi64>, !llvm.ptr<4> = dense<64> : vector<4xi64>, !llvm.ptr<5> = dense<32> : vector<4xi64>, !llvm.ptr<6> = dense<32> : vector<4xi64>, !llvm.ptr<7> = dense<[160, 256, 256, 32]> : vector<4xi64>, !llvm.ptr<8> = dense<[128, 128, 128, 48]> : vector<4xi64>, !llvm.ptr<9> = dense<[192, 256, 256, 32]> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.legal_int_widths" = array<i32: 32, 64>, "dlti.stack_alignment" = 32 : i64, "dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, fir.atomic_ignore_denormal_mode, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", fir.target_cpu = "generic-hsa", llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9", llvm.target_triple = "amdgcn-amd-amdhsa", omp.flags = #omp.flags<openmp_device_version = 31>, omp.is_gpu = true, omp.is_target_device = true, omp.requires = #omp<clause_requires none>, omp.target_triples = [], omp.version = #omp.version<version = 31>} {
+ llvm.func @_QQmain() attributes {fir.bindc_name = "TEST", omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, target_cpu = "generic-hsa"} {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "threads"} : (i64) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ %3 = llvm.mlir.constant(1 : i64) : i64
+ %4 = llvm.alloca %3 x i32 {bindc_name = "a"} : (i64) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %6 = llvm.mlir.constant(0 : i32) : i32
+ %7 = llvm.mlir.constant(128 : i32) : i32
+ %8 = llvm.mlir.constant(1 : i64) : i64
+ %9 = llvm.mlir.constant(1 : i64) : i64
+ llvm.store %7, %2 : i32, !llvm.ptr
+ llvm.store %6, %5 : i32, !llvm.ptr
+ %10 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "threads"}
+ %11 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "a"}
+ omp.target map_entries(%10 -> %arg0, %11 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+ %12 = llvm.mlir.constant(1 : i32) : i32
+ %13 = llvm.load %arg0 : !llvm.ptr -> i32
+ omp.parallel num_threads(%13 : i32) {
+ omp.atomic.update %arg1 : !llvm.ptr {
+ ^bb0(%arg2: i32):
+ %14 = llvm.add %arg2, %12 : i32
+ omp.yield(%14 : i32)
+ } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true>}
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
diff --git a/mlir/test/Target/LLVMIR/omptarget-debug-147063.mlir b/mlir/test/Target/LLVMIR/omptarget-debug-147063.mlir
new file mode 100644
index 0000000..12d389a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-debug-147063.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-translate -mlir-to-llvmir %s
+
+module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ omp.private {type = private} @_QFFfnEv_private_i32 : i32 loc(#loc1)
+ llvm.func internal @_QFPfn() {
+ %0 = llvm.mlir.constant(1 : i64) : i64 loc(#loc1)
+ %1 = llvm.alloca %0 x i32 {bindc_name = "v"} : (i64) -> !llvm.ptr loc(#loc1)
+ %2 = llvm.mlir.constant(1 : i32) : i32
+ omp.parallel private(@_QFFfnEv_private_i32 %1 -> %arg0 : !llvm.ptr) {
+ llvm.store %2, %arg0 : i32, !llvm.ptr loc(#loc2)
+ %4 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "v"} loc(#loc2)
+ omp.target map_entries(%4 -> %arg1 : !llvm.ptr) {
+ %5 = llvm.mlir.constant(1 : i32) : i32
+ %6 = llvm.load %arg1 : !llvm.ptr -> i32 loc(#loc3)
+ %7 = llvm.add %6, %5 : i32 loc(#loc3)
+ llvm.store %7, %arg1 : i32, !llvm.ptr loc(#loc3)
+ omp.terminator loc(#loc3)
+ } loc(#loc7)
+ omp.terminator
+ } loc(#loc4)
+ llvm.return
+ } loc(#loc6)
+}
+
+#di_file = #llvm.di_file<"target.f90" in "">
+#di_null_type = #llvm.di_null_type
+#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>,
+ sourceLanguage = DW_LANG_Fortran95, file = #di_file, producer = "flang",
+ isOptimized = false, emissionKind = LineTablesOnly>
+#di_subroutine_type = #llvm.di_subroutine_type<
+ callingConvention = DW_CC_program, types = #di_null_type>
+#di_subprogram = #llvm.di_subprogram<id = distinct[1]<>,
+ compileUnit = #di_compile_unit, scope = #di_file, name = "main",
+ file = #di_file, subprogramFlags = "Definition|MainSubprogram",
+ type = #di_subroutine_type>
+#di_subprogram1 = #llvm.di_subprogram<compileUnit = #di_compile_unit,
+ name = "target", file = #di_file, subprogramFlags = "Definition",
+ type = #di_subroutine_type>
+
+#loc1 = loc("test.f90":7:15)
+#loc2 = loc("test.f90":1:7)
+#loc3 = loc("test.f90":3:7)
+#loc4 = loc("test.f90":16:7)
+#loc6 = loc(fused<#di_subprogram>[#loc1])
+#loc7 = loc(fused<#di_subprogram1>[#loc3])
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
index 830610f..5d2861a 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
@@ -37,7 +37,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK-SAME: #[[ATTRS1:[0-9]+]]
// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB]] to ptr),
// CHECK-SAME: ptr @[[LOOP_BODY_FUNC:.*]], ptr %[[LOO_BODY_FUNC_ARG:.*]], i32 10,
-// CHECK-SAME: i32 %[[THREAD_NUM:.*]], i32 0)
+// CHECK-SAME: i32 %[[THREAD_NUM:.*]], i8 0)
// CHECK: define internal void @[[LOOP_BODY_FUNC]](i32 %[[CNT:.*]], ptr %[[LOOP_BODY_ARG_PTR:.*]]) #[[ATTRS2:[0-9]+]] {
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
index 0ebcec0..b42e387 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
@@ -25,7 +25,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: define void @[[FUNC_COLLAPSED_WSLOOP:.*]](ptr %[[ARG0:.*]])
// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB2:[0-9]+]] to ptr),
// CHECK-SAME: ptr @[[COLLAPSED_WSLOOP_BODY_FN:.*]], ptr %[[STRUCT_ARG:.*]], i32 10000,
-// CHECK-SAME: i32 %[[NUM_THREADS:.*]], i32 0)
+// CHECK-SAME: i32 %[[NUM_THREADS:.*]], i8 0)
// CHECK: define internal void @[[COLLAPSED_WSLOOP_BODY_FN]](i32 %[[LOOP_CNT:.*]], ptr %[[LOOP_BODY_ARG:.*]])
// CHECK: %[[TMP0:.*]] = urem i32 %[[LOOP_CNT]], 100
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
index a9f913b..7be635f 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
@@ -37,7 +37,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: %[[GEP:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
// CHECK: store ptr %[[ARG0]], ptr addrspace(5) %[[GEP]], align 8
// CHECK: %[[NUM_THREADS:.*]] = call i32 @omp_get_num_threads()
-// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr), ptr @[[LOOP_BODY_FN:.*]], ptr %[[STRUCTARG_ASCAST]], i32 10, i32 %[[NUM_THREADS]], i32 0)
+// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr), ptr @[[LOOP_BODY_FN:.*]], ptr %[[STRUCTARG_ASCAST]], i32 10, i32 %[[NUM_THREADS]], i32 0, i8 0)
// CHECK: define internal void @[[LOOP_BODY_FN]](i32 %[[LOOP_CNT:.*]], ptr %[[LOOP_BODY_ARG:.*]])
// CHECK: %[[GEP2:.*]] = getelementptr { ptr }, ptr %[[LOOP_BODY_ARG]], i32 0, i32 0
@@ -46,6 +46,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: store i32 %[[VAL0:.*]], ptr %[[GEP3]], align 4
// CHECK: define void @[[FUNC_EMPTY_WSLOOP:.*]]()
-// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB2:[0-9]+]] to ptr), ptr @[[LOOP_EMPTY_BODY_FN:.*]], ptr null, i32 10, i32 %[[NUM_THREADS:.*]], i32 0)
+// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB2:[0-9]+]] to ptr), ptr @[[LOOP_EMPTY_BODY_FN:.*]], ptr null, i32 10, i32 %[[NUM_THREADS:.*]], i32 0, i8 0)
// CHECK: define internal void @[[LOOP_EMPTY_BODY_FN]](i32 %[[LOOP_CNT:.*]])
diff --git a/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir b/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir
index 234604e..902548c 100644
--- a/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-simd-aligned.mlir
@@ -58,3 +58,25 @@ llvm.func @_QPsimd_aligned_allocatable() {
}
llvm.return
}
+
+//CHECK-LABEL: define void @_QPsimd_aligned_non_power_of_two() {
+//CHECK: %[[A_ADDR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
+//CHECK: %[[B_ADDR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
+//CHECK: %[[LOAD_B:.*]] = load ptr, ptr %[[B_ADDR]], align 8
+//CHECK: call void @llvm.assume(i1 true) [ "align"(ptr %[[LOAD_B]], i64 64) ]
+//CHECK-NOT: call void @llvm.assume(i1 true) [ "align"(ptr %{{.*}}, i64 257) ]
+llvm.func @_QPsimd_aligned_non_power_of_two() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "a"} : (i64) -> !llvm.ptr
+ %2 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "b"} : (i64) -> !llvm.ptr
+ %3 = llvm.mlir.constant(1 : i32) : i32
+ %4 = llvm.mlir.constant(10 : i32) : i32
+ %5 = llvm.mlir.constant(1 : i32) : i32
+ omp.simd aligned(%1 : !llvm.ptr -> 257 : i64, %2 : !llvm.ptr -> 64 : i64) {
+ omp.loop_nest (%arg0) : i32 = (%3) to (%4) inclusive step (%5) {
+ omp.yield
+ }
+ }
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
new file mode 100644
index 0000000..c1620cb
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: declare ptr @llvm_ptr_address_space(ptr addrspace(1), ptr addrspace(3))
+llvm.func @llvm_ptr_address_space(!ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) -> !ptr.ptr<#llvm.address_space<0>>
+
+// CHECK-LABEL: define void @llvm_ops_with_ptr_values
+// CHECK-SAME: (ptr %[[ARG:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = load ptr addrspace(1), ptr %[[ARG]], align 8
+// CHECK-NEXT: store ptr addrspace(1) %[[V0]], ptr %[[ARG]], align 8
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) {
+ %1 = llvm.load %arg0 : !llvm.ptr -> !ptr.ptr<#llvm.address_space<1>>
+ llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 740990a..a464358 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -86,12 +86,12 @@ llvm.func @kernel_func_unsafe_fp_atomics()
}
llvm.func @rocdl.lane_id() -> i32 {
- // CHECK: [[mbcntlo:%.+]] = call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0)
- // CHECK-NEXT: call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[mbcntlo]])
+ // CHECK: [[mbcntlo:%.+]] = call noundef range(i32 0, 32) i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0)
+ // CHECK-NEXT: call noundef range(i32 0, 64) i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 [[mbcntlo]])
%0 = llvm.mlir.constant(-1 : i32) : i32
%1 = llvm.mlir.constant(0 : i32) : i32
- %2 = rocdl.mbcnt.lo %0, %1 : (i32, i32) -> i32
- %3 = rocdl.mbcnt.hi %0, %2 : (i32, i32) -> i32
+ %2 = rocdl.mbcnt.lo %0, %1 {res_attrs = [{llvm.noundef, llvm.range = #llvm.constant_range<i32, 0, 32>}]} : (i32, i32) -> i32
+ %3 = rocdl.mbcnt.hi %0, %2 {res_attrs = [{llvm.noundef, llvm.range = #llvm.constant_range<i32, 0, 64>}]} : (i32, i32) -> i32
llvm.return %3 : i32
}
@@ -125,6 +125,23 @@ llvm.func @rocdl.ballot64(%pred : i1) -> i64 {
llvm.return %0 : i64
}
+llvm.func @rocdl.readfirstlane(%src0 : f32, %src1: f64, %src2: i32, %src3: vector<2 x f32>) -> f32 {
+ // CHECK-LABEL: rocdl.readfirstlane
+ // CHECK: call float @llvm.amdgcn.readfirstlane.f32(float %{{.*}})
+ %0 = rocdl.readfirstlane %src0 : f32
+
+ // CHECK: call double @llvm.amdgcn.readfirstlane.f64(double %{{.*}})
+ %1 = rocdl.readfirstlane %src1 : f64
+
+ // CHECK: call i32 @llvm.amdgcn.readfirstlane.i32(i32 %{{.*}})
+ %2 = rocdl.readfirstlane %src2 : i32
+
+ // CHECK: call <2 x float> @llvm.amdgcn.readfirstlane.v2f32(<2 x float> %{{.*}})
+ %3 = rocdl.readfirstlane %src3 : vector<2 x f32>
+
+ llvm.return %0 : f32
+}
+
llvm.func @rocdl.readlane(%src0 : f32, %src1: f64, %src2: i32, %src3: vector<2 x f32>) -> f32 {
%idx = llvm.mlir.constant(0 : i32) : i32
@@ -924,6 +941,20 @@ llvm.func @rocdl.permlanex16(%src0 : f32, %src1 : i32, %src2 : vector<2 x f32>,
llvm.return %ret0 : f32
}
+llvm.func @rocdl.permlane16.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+ // CHECK-LABEL: rocdl.permlane16.swap
+ // CHECK: call { i32, i32 } @llvm.amdgcn.permlane16.swap(i32 %{{.*}}, i32 %{{.*}}, i1 false, i1 true)
+ %ret = rocdl.permlane16.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+ llvm.return %ret : !llvm.struct<(i32, i32)>
+}
+
+llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
+ // CHECK-LABEL: rocdl.permlane32.swap
+ // CHECK: call { i32, i32 } @llvm.amdgcn.permlane32.swap(i32 %{{.*}}, i32 %{{.*}}, i1 false, i1 true)
+ %ret = rocdl.permlane32.swap %src, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
+ llvm.return %ret : !llvm.struct<(i32, i32)>
+}
+
llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
%r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
diff --git a/mlir/test/Target/LLVMIR/target-to-data-layout-and-target-features.mlir b/mlir/test/Target/LLVMIR/target-to-data-layout-and-target-features.mlir
new file mode 100644
index 0000000..b6b2976
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-to-data-layout-and-target-features.mlir
@@ -0,0 +1,137 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt -llvm-target-to-data-layout -split-input-file %s | FileCheck --check-prefix=DATA-LAYOUT %s
+// RUN: mlir-opt -llvm-target-to-target-features -split-input-file %s | FileCheck --check-prefix=TARGET-FEATURES %s
+
+// DATA-LAYOUT: module attributes
+// DATA-LAYOUT-SAME: dlti.dl_spec = #dlti.dl_spec
+// DATA-LAYOUT-SAME: "dlti.endianness" = "little"
+// DATA-LAYOUT-SAME: llvm.target = #llvm.target<
+// DATA-LAYOUT-SAME: triple = "x86_64-unknown-linux"
+// DATA-LAYOUT-SAME: chip = ""
+// DATA-LAYOUT-NOT: features =
+
+// TARGET-FEATURES: module attributes
+// TARGET-FEATURES-NOT: dlti.dl_spec
+// TARGET-FEATURES-SAME: llvm.target = #llvm.target<
+// TARGET-FEATURES-SAME: triple = "x86_64-unknown-linux"
+// TARGET-FEATURES-SAME: chip = ""
+// TARGET-FEATURES-SAME: features = <[
+// TARGET-FEATURES-SAME: +64bit
+// TARGET-FEATURES-NOT: +avx
+// TARGET-FEATURES-SAME: +sse
+// TARGET-FEATURES-NOT: +mmx
+
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = ""> } {
+}
+
+// -----
+
+// DATA-LAYOUT: module attributes
+// DATA-LAYOUT-SAME: dlti.dl_spec = #dlti.dl_spec
+// DATA-LAYOUT-SAME: "dlti.endianness" = "little"
+// DATA-LAYOUT-SAME: llvm.target = #llvm.target<
+// DATA-LAYOUT-SAME: triple = "x86_64-unknown-linux"
+// DATA-LAYOUT-SAME: chip = ""
+// DATA-LAYOUT-SAME: features = <["+mmx", "+sse"]>
+
+// TARGET-FEATURES: module attributes
+// TARGET-FEATURES-NOT: dlti.dl_spec
+// TARGET-FEATURES-SAME: llvm.target = #llvm.target<
+// TARGET-FEATURES-SAME: triple = "x86_64-unknown-linux"
+// TARGET-FEATURES-SAME: chip = ""
+// TARGET-FEATURES-SAME: features = <[
+// TARGET-FEATURES-SAME: +64bit
+// TARGET-FEATURES-NOT: +avx
+// TARGET-FEATURES-SAME: +mmx
+// TARGET-FEATURES-SAME: +sse
+
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "",
+ features = <["+mmx", "+sse"]>> } {
+}
+
+// -----
+
+// DATA-LAYOUT: module attributes
+// DATA-LAYOUT-SAME: dlti.dl_spec = #dlti.dl_spec
+// DATA-LAYOUT-SAME: "dlti.endianness" = "little"
+// DATA-LAYOUT-SAME: llvm.target = #llvm.target<
+// DATA-LAYOUT-SAME: triple = "x86_64-unknown-linux"
+// DATA-LAYOUT-SAME: chip = "skylake"
+// DATA-LAYOUT-NOT: features =
+
+// TARGET-FEATURES: module attributes
+// TARGET-FEATURES-NOT: dlti.dl_spec
+// TARGET-FEATURES-SAME: llvm.target = #llvm.target<
+// TARGET-FEATURES-SAME: triple = "x86_64-unknown-linux"
+// TARGET-FEATURES-SAME: chip = "skylake"
+// TARGET-FEATURES-SAME: features = <[
+// TARGET-FEATURES-SAME: +64bit
+// TARGET-FEATURES-SAME: +avx
+// TARGET-FEATURES-SAME: +avx2
+// TARGET-FEATURES-NOT: +avx512f
+// TARGET-FEATURES-SAME: +mmx
+// TARGET-FEATURES-SAME: +sse
+
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "skylake"> } {
+}
+
+// -----
+
+// DATA-LAYOUT: module attributes
+// DATA-LAYOUT-SAME: dlti.dl_spec = #dlti.dl_spec
+// DATA-LAYOUT-SAME: "dlti.endianness" = "little"
+// DATA-LAYOUT-SAME: llvm.target = #llvm.target<
+// DATA-LAYOUT-SAME: triple = "x86_64-unknown-linux"
+// DATA-LAYOUT-SAME: chip = "skylake"
+// DATA-LAYOUT-SAME: features = <["-sse", "-avx"]>
+
+// TARGET-FEATURES: module attributes
+// TARGET-FEATURES-NOT: dlti.dl_spec
+// TARGET-FEATURES-SAME: llvm.target = #llvm.target<
+// TARGET-FEATURES-SAME: triple = "x86_64-unknown-linux"
+// TARGET-FEATURES-SAME: chip = "skylake"
+// TARGET-FEATURES-SAME: features = <[
+// TARGET-FEATURES-SAME: +64bit
+// TARGET-FEATURES-NOT: +avx
+// TARGET-FEATURES-NOT: +avx2
+// TARGET-FEATURES-SAME: +mmx
+// TARGET-FEATURES-NOT: +sse
+
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "skylake",
+ features = <["-sse", "-avx"]>> } {
+}
+
+// -----
+
+// DATA-LAYOUT: module attributes
+// DATA-LAYOUT-SAME: dlti.dl_spec = #dlti.dl_spec
+// DATA-LAYOUT-SAME: "dlti.endianness" = "little"
+// DATA-LAYOUT-SAME: index = 32
+// DATA-LAYOUT-SAME: llvm.target = #llvm.target<
+// DATA-LAYOUT-SAME: triple = "x86_64-unknown-linux"
+// DATA-LAYOUT-SAME: chip = "skylake"
+// DATA-LAYOUT-SAME: features = <["-mmx", "+avx512f"]>
+
+// TARGET-FEATURES: module attributes
+// TARGET-FEATURES-SAME: #dlti.dl_spec<index = 32 : i64>
+// TARGET-FEATURES-SAME: llvm.target = #llvm.target<
+// TARGET-FEATURES-SAME: triple = "x86_64-unknown-linux"
+// TARGET-FEATURES-SAME: chip = "skylake"
+// TARGET-FEATURES-SAME: features = <[
+// TARGET-FEATURES-SAME: +64bit
+// TARGET-FEATURES-SAME: +avx
+// TARGET-FEATURES-SAME: +avx2
+// TARGET-FEATURES-SAME: +avx512f
+// TARGET-FEATURES-NOT: +mmx
+// TARGET-FEATURES-SAME: +sse
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<index = 32>,
+ llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "skylake",
+ features = <["-mmx", "+avx512f"]>> } {
+}
diff --git a/mlir/test/Target/LLVMIR/target-to-data-layout-invalid.mlir b/mlir/test/Target/LLVMIR/target-to-data-layout-invalid.mlir
new file mode 100644
index 0000000..c0ff534
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-to-data-layout-invalid.mlir
@@ -0,0 +1,9 @@
+// REQUIRES: target=x86{{.*}}
+// RUN: mlir-opt %s -llvm-target-to-data-layout --split-input-file --verify-diagnostics
+
+// expected-error @+1 {{failed to obtain llvm::DataLayout for #llvm.target}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<index = 32>,
+llvm.target =
+ #llvm.target<triple="x64_86-unknown-linux",
+ chip="NON-EXISTING CHIP"> } {
+}
diff --git a/mlir/test/Target/LLVMIR/target-to-data-layout-no-init.mlir b/mlir/test/Target/LLVMIR/target-to-data-layout-no-init.mlir
new file mode 100644
index 0000000..2a9e978
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-to-data-layout-no-init.mlir
@@ -0,0 +1,12 @@
+// REQUIRES: target=x86{{.*}}
+// RUN: mlir-opt %s -llvm-target-to-data-layout="initialize-llvm-targets=false" --split-input-file --verify-diagnostics
+
+// Without initializing the (right) LLVM targets/backends ("initialize-llvm-targets=false"),
+// it is not possible to obtain LLVM's DataLayout for the target.
+
+// expected-error @+1 {{failed to obtain llvm::DataLayout for #llvm.target}}
+module attributes { dlti.dl_spec = #dlti.dl_spec<index = 32>,
+llvm.target =
+ #llvm.target<triple="x64_86-unknown-linux",
+ chip="skylake"> } {
+}
diff --git a/mlir/test/Target/LLVMIR/target-to-target-features-dlti-query.mlir b/mlir/test/Target/LLVMIR/target-to-target-features-dlti-query.mlir
new file mode 100644
index 0000000..9a1f49b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-to-target-features-dlti-query.mlir
@@ -0,0 +1,75 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt -transform-interpreter -split-input-file %s --verify-diagnostics
+
+// Check that processor features, like AVX, are appropriated derived and queryable.
+
+// expected-remark @+2 {{attr associated to ["features", "+avx"] = unit}}
+// expected-remark @below {{attr associated to ["features", "avx"] = true}}
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "skylake">,
+ test.dl_spec = #dlti.dl_spec<index = 32> } {
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+ %mod = transform.apply_registered_pass "llvm-target-to-target-features" to %module : (!transform.any_op) -> !transform.any_op
+ %plus_avx = transform.dlti.query ["features", "+avx"] at %mod : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %plus_avx, "attr associated to [\"features\", \"+avx\"] =" at %mod : !transform.any_param, !transform.any_op
+ %avx = transform.dlti.query ["features", "avx"] at %mod : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %avx, "attr associated to [\"features\", \"avx\"] =" at %mod : !transform.any_param, !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that newer processor features, like AMX, are appropriated derived and queryable.
+
+// expected-remark @+2 {{attr associated to ["features", "+amx-bf16"] = unit}}
+// expected-remark @below {{attr associated to ["features", "amx-bf16"] = true}}
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "sapphirerapids">,
+ test.dl_spec = #dlti.dl_spec<index = 32> } {
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+ %mod = transform.apply_registered_pass "llvm-target-to-target-features" to %module : (!transform.any_op) -> !transform.any_op
+ %plus_avx = transform.dlti.query ["features", "+amx-bf16"] at %mod : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %plus_avx, "attr associated to [\"features\", \"+amx-bf16\"] =" at %mod : !transform.any_param, !transform.any_op
+ %avx = transform.dlti.query ["features", "amx-bf16"] at %mod : (!transform.any_op) -> !transform.any_param
+ transform.debug.emit_param_as_remark %avx, "attr associated to [\"features\", \"amx-bf16\"] =" at %mod : !transform.any_param, !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that features that a processor does not have, AMX in this case,
+// aren't derived and hence that querying for them will fail.
+
+// expected-error @+2 {{target op of failed DLTI query}}
+// expected-note @below {{key "+amx-bf16" has no DLTI-mapping per attr: #llvm.target_features}}
+module attributes { llvm.target = #llvm.target<triple = "x86_64-unknown-linux",
+ chip = "skylake">,
+ test.dl_spec = #dlti.dl_spec<index = 32> } {
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+ %mod = transform.apply_registered_pass "llvm-target-to-target-features" to %module : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["features", "+amx-bf16"] at %mod : (!transform.any_op) -> !transform.any_param
+ transform.yield
+ }
+}
diff --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
index 275e586..7fb8af1 100644
--- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
+++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
@@ -1,17 +1,36 @@
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
-// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-
-// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
spirv.module Logical Vulkan requires #spirv.vce<v1.3,
[VulkanMemoryModel, Shader, TensorsARM, Linkage], [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
- // CHECK-LABEL: @arm_tensor_of_i32
- spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK-LABEL: @rank_1_arm_tensor_of_i32
+ spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+ %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
+ }
+
+ // CHECK-LABEL: @rank_2_arm_tensor_of_i32
+ spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
%0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // CHECK-LABEL: @rank_3_arm_tensor_of_i32
+ spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+ %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
+ }
+
+ // CHECK-LABEL: @rank_4_arm_tensor_of_i32
+ spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+ %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
+ }
+
// CHECK-LABEL: @splat_arm_tensor_of_i32
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
@@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
- // CHECK-LABEL: @arm_tensor_of_f32
- spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK-LABEL: @rank_1_arm_tensor_of_f32
+ spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
+ %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+ }
+
+ // CHECK-LABEL: @rank_2_arm_tensor_of_f32
+ spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
- %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+ // CHECK-LABEL: @rank_3_arm_tensor_of_f32
+ spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
+ %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
+ }
+
+ // CHECK-LABEL: @rank_4_arm_tensor_of_f32
+ spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
+ %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
+ }
+
// CHECK-LABEL: @splat_arm_tensor_of_f32
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
diff --git a/mlir/test/Target/SPIRV/debug-negative.mlir b/mlir/test/Target/SPIRV/debug-negative.mlir
new file mode 100644
index 0000000..2c82687
--- /dev/null
+++ b/mlir/test/Target/SPIRV/debug-negative.mlir
@@ -0,0 +1,5 @@
+// RUN: mlir-translate %s --test-spirv-roundtrip-debug --no-implicit-module --verify-diagnostics
+
+// expected-error@below {{SPV_KHR_non_semantic_info extension not available}}
+spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader], []> attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+}
diff --git a/mlir/test/Target/SPIRV/debug.mlir b/mlir/test/Target/SPIRV/debug.mlir
index 58bf364..5a7ed19 100644
--- a/mlir/test/Target/SPIRV/debug.mlir
+++ b/mlir/test/Target/SPIRV/debug.mlir
@@ -1,69 +1,70 @@
// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip-debug -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- // CHECK: loc({{".*debug.mlir"}}:5:3)
+spirv.module Logical GLSL450 requires #spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], [SPV_KHR_non_semantic_info, SPV_KHR_storage_buffer_storage_class]> attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], [SPV_KHR_non_semantic_info, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+ // CHECK: loc({{".*debug.mlir"}}:6:3)
spirv.GlobalVariable @var0 bind(0, 1) : !spirv.ptr<f32, Input>
spirv.func @arithmetic(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" {
- // CHECK: loc({{".*debug.mlir"}}:8:10)
+ // CHECK: loc({{".*debug.mlir"}}:9:10)
%0 = spirv.FAdd %arg0, %arg1 : vector<4xf32>
- // CHECK: loc({{".*debug.mlir"}}:10:10)
+ // CHECK: loc({{".*debug.mlir"}}:11:10)
%1 = spirv.FNegate %arg0 : vector<4xf32>
spirv.Return
}
spirv.func @atomic(%ptr: !spirv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:16:10)
+ // CHECK: loc({{".*debug.mlir"}}:17:10)
%1 = spirv.AtomicAnd <Device> <None> %ptr, %value : !spirv.ptr<i32, Workgroup>
spirv.Return
}
spirv.func @bitwiser(%arg0 : i32, %arg1 : i32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:22:10)
+ // CHECK: loc({{".*debug.mlir"}}:23:10)
%0 = spirv.BitwiseAnd %arg0, %arg1 : i32
spirv.Return
}
spirv.func @convert(%arg0 : f32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:28:10)
+ // CHECK: loc({{".*debug.mlir"}}:29:10)
%0 = spirv.ConvertFToU %arg0 : f32 to i32
spirv.Return
}
spirv.func @composite(%arg0 : !spirv.struct<(f32, !spirv.struct<(!spirv.array<4xf32>, f32)>)>, %arg1: !spirv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:34:10)
+ // CHECK: loc({{".*debug.mlir"}}:35:10)
%0 = spirv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spirv.array<4xf32> into !spirv.struct<(f32, !spirv.struct<(!spirv.array<4xf32>, f32)>)>
- // CHECK: loc({{".*debug.mlir"}}:36:10)
+ // CHECK: loc({{".*debug.mlir"}}:37:10)
%1 = spirv.CompositeConstruct %arg2, %arg3 : (f32, f32) -> vector<2xf32>
spirv.Return
}
spirv.func @group_non_uniform(%val: f32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:42:10)
+ // CHECK: loc({{".*debug.mlir"}}:43:10)
%0 = spirv.GroupNonUniformFAdd <Workgroup> <Reduce> %val : f32 -> f32
spirv.Return
}
spirv.func @local_var() "None" {
%zero = spirv.Constant 0: i32
- // CHECK: loc({{".*debug.mlir"}}:49:12)
+ // CHECK: loc({{".*debug.mlir"}}:50:12)
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
spirv.Return
}
spirv.func @logical(%arg0: i32, %arg1: i32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:55:10)
+ // CHECK: loc({{".*debug.mlir"}}:56:10)
%0 = spirv.IEqual %arg0, %arg1 : i32
spirv.Return
}
spirv.func @memory_accesses(%arg0 : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
- // CHECK: loc({{".*debug.mlir"}}:61:10)
+ // CHECK: loc({{".*debug.mlir"}}:62:10)
%2 = spirv.AccessChain %arg0[%arg1, %arg2] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
- // CHECK: loc({{".*debug.mlir"}}:63:10)
+ // CHECK: loc({{".*debug.mlir"}}:64:10)
%3 = spirv.Load "StorageBuffer" %2 : f32
- // CHECK: loc({{.*debug.mlir"}}:65:5)
+ // CHECK: loc({{.*debug.mlir"}}:66:5)
spirv.Store "StorageBuffer" %2, %3 : f32
- // CHECK: loc({{".*debug.mlir"}}:67:5)
+ // CHECK: loc({{".*debug.mlir"}}:68:5)
spirv.Return
}
@@ -73,49 +74,49 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%ivar = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
%jvar = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
spirv.mlir.loop {
- // CHECK: loc({{".*debug.mlir"}}:75:5)
+ // CHECK: loc({{".*debug.mlir"}}:76:5)
spirv.Branch ^header
^header:
%ival0 = spirv.Load "Function" %ivar : i32
%icmp = spirv.SLessThan %ival0, %count : i32
- // CHECK: loc({{".*debug.mlir"}}:75:5)
+ // CHECK: loc({{".*debug.mlir"}}:76:5)
spirv.BranchConditional %icmp, ^body, ^merge
^body:
spirv.Store "Function" %jvar, %zero : i32
spirv.mlir.loop {
- // CHECK: loc({{".*debug.mlir"}}:85:7)
+ // CHECK: loc({{".*debug.mlir"}}:86:7)
spirv.Branch ^header
^header:
%jval0 = spirv.Load "Function" %jvar : i32
%jcmp = spirv.SLessThan %jval0, %count : i32
- // CHECK: loc({{".*debug.mlir"}}:85:7)
+ // CHECK: loc({{".*debug.mlir"}}:86:7)
spirv.BranchConditional %jcmp, ^body, ^merge
^body:
- // CHECK: loc({{".*debug.mlir"}}:95:9)
+ // CHECK: loc({{".*debug.mlir"}}:96:9)
spirv.Branch ^continue
^continue:
%jval1 = spirv.Load "Function" %jvar : i32
%add = spirv.IAdd %jval1, %one : i32
spirv.Store "Function" %jvar, %add : i32
- // CHECK: loc({{".*debug.mlir"}}:101:9)
+ // CHECK: loc({{".*debug.mlir"}}:102:9)
spirv.Branch ^header
^merge:
- // CHECK: loc({{".*debug.mlir"}}:85:7)
+ // CHECK: loc({{".*debug.mlir"}}:86:7)
spirv.mlir.merge
- // CHECK: loc({{".*debug.mlir"}}:85:7)
+ // CHECK: loc({{".*debug.mlir"}}:86:7)
}
- // CHECK: loc({{".*debug.mlir"}}:108:7)
+ // CHECK: loc({{".*debug.mlir"}}:109:7)
spirv.Branch ^continue
^continue:
%ival1 = spirv.Load "Function" %ivar : i32
%add = spirv.IAdd %ival1, %one : i32
spirv.Store "Function" %ivar, %add : i32
- // CHECK: loc({{".*debug.mlir"}}:114:7)
+ // CHECK: loc({{".*debug.mlir"}}:115:7)
spirv.Branch ^header
^merge:
- // CHECK: loc({{".*debug.mlir"}}:75:5)
+ // CHECK: loc({{".*debug.mlir"}}:76:5)
spirv.mlir.merge
- // CHECK: loc({{".*debug.mlir"}}:75:5)
+ // CHECK: loc({{".*debug.mlir"}}:76:5)
}
spirv.Return
}
@@ -126,21 +127,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%two = spirv.Constant 2: i32
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
spirv.mlir.selection {
- // CHECK: loc({{".*debug.mlir"}}:128:5)
+ // CHECK: loc({{".*debug.mlir"}}:129:5)
spirv.BranchConditional %cond [5, 10], ^then, ^else
^then:
spirv.Store "Function" %var, %one : i32
- // CHECK: loc({{".*debug.mlir"}}:134:7)
+ // CHECK: loc({{".*debug.mlir"}}:135:7)
spirv.Branch ^merge
^else:
spirv.Store "Function" %var, %two : i32
- // CHECK: loc({{".*debug.mlir"}}:138:7)
+ // CHECK: loc({{".*debug.mlir"}}:139:7)
spirv.Branch ^merge
^merge:
- // CHECK: loc({{".*debug.mlir"}}:128:5)
+ // CHECK: loc({{".*debug.mlir"}}:129:5)
spirv.mlir.merge
- // CHECK: loc({{".*debug.mlir"}}:128:5)
+ // CHECK: loc({{".*debug.mlir"}}:129:5)
}
spirv.Return
}
+
+ spirv.EntryPoint "GLCompute" @local_var
}
diff --git a/mlir/test/Target/SPIRV/mlir-translate.mlir b/mlir/test/Target/SPIRV/mlir-translate.mlir
new file mode 100644
index 0000000..cbce351
--- /dev/null
+++ b/mlir/test/Target/SPIRV/mlir-translate.mlir
@@ -0,0 +1,29 @@
+// Check that `--spirv-save-validation-files-with-prefix` generates
+// a correct number of files.
+
+// REQUIRES: shell
+// RUN: rm -rf %t
+// RUN: mkdir %t && mlir-translate --serialize-spirv --no-implicit-module \
+// RUN: --split-input-file --spirv-save-validation-files-with-prefix=%t/foo %s \
+// RUN: && ls %t/foo*.spv | wc -l | FileCheck %s
+// RUN: rm -rf %t
+
+// CHECK: 4
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+}
diff --git a/mlir/test/Target/SPIRV/module.mlir b/mlir/test/Target/SPIRV/module.mlir
index dcdcab8..d4000df 100644
--- a/mlir/test/Target/SPIRV/module.mlir
+++ b/mlir/test/Target/SPIRV/module.mlir
@@ -1,21 +1,29 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s
+
+// REQUIRES: shell
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ ls %t/module*.spv | xargs -I{} spirv-val {} %}
// CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK-NEXT: spirv.func @foo() "Inline" {
// CHECK-NEXT: spirv.Return
// CHECK-NEXT: }
+// CHECK-NEXT: spirv.EntryPoint "Vertex" @foo
// CHECK-NEXT: }
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.func @foo() -> () "Inline" {
spirv.Return
}
+ spirv.EntryPoint "Vertex" @foo
}
// -----
// CHECK: v1.5
-spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader, Linkage], []> {
}
// -----
@@ -26,13 +34,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.6, [Shader, Linkage], []> {
// -----
-// CHECK: [Shader, Float16]
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Float16], []> {
+// CHECK: [Shader, Float16, Linkage]
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Float16, Linkage], []> {
}
// -----
// CHECK: [SPV_KHR_float_controls, SPV_KHR_subgroup_vote]
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_float_controls, SPV_KHR_subgroup_vote]> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_float_controls, SPV_KHR_subgroup_vote]> {
}
diff --git a/mlir/test/Target/Wasm/abs.mlir b/mlir/test/Target/Wasm/abs.mlir
new file mode 100644
index 0000000..9c45ba7
--- /dev/null
+++ b/mlir/test/Target/Wasm/abs.mlir
@@ -0,0 +1,23 @@
+// RUN: yaml2obj %S/inputs/abs.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "abs_f32") (result f32)
+ f32.const 10
+ f32.abs)
+
+ (func (export "abs_f64") (result f64)
+ f64.const 10
+ f64.abs)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @abs_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @abs_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/and.mlir b/mlir/test/Target/Wasm/and.mlir
new file mode 100644
index 0000000..4c0fea0
--- /dev/null
+++ b/mlir/test/Target/Wasm/and.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/and.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "and_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.and)
+
+ (func (export "and_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.and)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @and_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.and %0 %1 : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @and_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.and %0 %1 : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/bad_wasm_version.yaml b/mlir/test/Target/Wasm/bad_wasm_version.yaml
new file mode 100644
index 0000000..f834afb
--- /dev/null
+++ b/mlir/test/Target/Wasm/bad_wasm_version.yaml
@@ -0,0 +1,8 @@
+# RUN: yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s
+
+# CHECK: unsupported Wasm version
+
+--- !WASM
+FileHeader:
+ Version: 0xDEADBEEF
+...
diff --git a/mlir/test/Target/Wasm/clz.mlir b/mlir/test/Target/Wasm/clz.mlir
new file mode 100644
index 0000000..3e6641d
--- /dev/null
+++ b/mlir/test/Target/Wasm/clz.mlir
@@ -0,0 +1,25 @@
+// RUN: yaml2obj %S/inputs/clz.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "clz_i32") (result i32)
+ i32.const 10
+ i32.clz
+ )
+
+ (func (export "clz_i64") (result i64)
+ i64.const 10
+ i64.clz
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @clz_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @clz_i64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/const.mlir b/mlir/test/Target/Wasm/const.mlir
new file mode 100644
index 0000000..aa9e76f
--- /dev/null
+++ b/mlir/test/Target/Wasm/const.mlir
@@ -0,0 +1,37 @@
+// RUN: yaml2obj %S/inputs/const.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func(result i32)
+ i32.const 1
+ )
+ (func (result i64)
+ i64.const 3
+ )
+ (func (result f32)
+ f32.const 4.0
+ )
+ (func (result f64)
+ f64.const 9.0
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i64
+// CHECK: wasmssa.return %[[VAL_0]] : i64
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 4.000000e+00 : f32
+// CHECK: wasmssa.return %[[VAL_0]] : f32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 9.000000e+00 : f64
+// CHECK: wasmssa.return %[[VAL_0]] : f64
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/copysign.mlir b/mlir/test/Target/Wasm/copysign.mlir
new file mode 100644
index 0000000..33d7a56
--- /dev/null
+++ b/mlir/test/Target/Wasm/copysign.mlir
@@ -0,0 +1,31 @@
+// RUN: yaml2obj %S/inputs/copysign.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "copysign_f32") (result f32)
+ f32.const 10
+ f32.const 1
+ f32.copysign
+ )
+
+ (func (export "copysign_f64") (result f64)
+ f64.const 10
+ f64.const 1
+ f64.copysign
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @copysign_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f32
+// CHECK: wasmssa.return %[[VAL_2]] : f32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @copysign_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f64
+// CHECK: wasmssa.return %[[VAL_2]] : f64
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/ctz.mlir b/mlir/test/Target/Wasm/ctz.mlir
new file mode 100644
index 0000000..6c0806f
--- /dev/null
+++ b/mlir/test/Target/Wasm/ctz.mlir
@@ -0,0 +1,25 @@
+// RUN: yaml2obj %S/inputs/ctz.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "ctz_i32") (result i32)
+ i32.const 10
+ i32.ctz
+ )
+
+ (func (export "ctz_i64") (result i64)
+ i64.const 10
+ i64.ctz
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @ctz_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @ctz_i64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/div.mlir b/mlir/test/Target/Wasm/div.mlir
new file mode 100644
index 0000000..c91f780
--- /dev/null
+++ b/mlir/test/Target/Wasm/div.mlir
@@ -0,0 +1,127 @@
+// RUN: yaml2obj %S/inputs/div.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func (export "div_u_i32") (result i32)
+ i32.const 10
+ i32.const 2
+ i32.div_u
+ )
+
+ (func (export "div_u_i32_zero") (result i32)
+ i32.const 10
+ i32.const 0
+ i32.div_u
+ )
+
+ (func (export "div_s_i32") (result i32)
+ i32.const 10
+ i32.const 2
+ i32.div_s
+ )
+
+ (func (export "div_s_i32_zero") (result i32)
+ i32.const 10
+ i32.const 0
+ i32.div_s
+ )
+
+ (func (export "div_u_i64") (result i64)
+ i64.const 10
+ i64.const 2
+ i64.div_u
+ )
+
+ ;; explode
+ (func (export "div_u_i64_zero") (result i64)
+ i64.const 10
+ i64.const 0
+ i64.div_u
+ )
+
+ (func (export "div_s_i64") (result i64)
+ i64.const 10
+ i64.const 2
+ i64.div_s
+ )
+
+ ;; explode
+ (func (export "div_s_i64_zero") (result i64)
+ i64.const 10
+ i64.const 0
+ i64.div_s
+ )
+
+ (func (export "div_f32") (result f32)
+ f32.const 10
+ f32.const 2
+ f32.div
+ )
+
+ (func (export "div_f64") (result f64)
+ f64.const 10
+ f64.const 2
+ f64.div
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @div_u_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @div_u_i32_zero() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @div_s_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @div_s_i32_zero() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @div_u_i64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64
+// CHECK: wasmssa.return %[[VAL_2]] : i64
+
+// CHECK-LABEL: wasmssa.func @div_u_i64_zero() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64
+// CHECK: wasmssa.return %[[VAL_2]] : i64
+
+// CHECK-LABEL: wasmssa.func @div_s_i64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64
+// CHECK: wasmssa.return %[[VAL_2]] : i64
+
+// CHECK-LABEL: wasmssa.func @div_s_i64_zero() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64
+// CHECK: wasmssa.return %[[VAL_2]] : i64
+
+// CHECK-LABEL: wasmssa.func @div_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f32
+// CHECK: wasmssa.return %[[VAL_2]] : f32
+
+// CHECK-LABEL: wasmssa.func @div_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f64
+// CHECK: wasmssa.return %[[VAL_2]] : f64
diff --git a/mlir/test/Target/Wasm/function_export_out_of_scope.yaml b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
new file mode 100644
index 0000000..b08c2c8
--- /dev/null
+++ b/mlir/test/Target/Wasm/function_export_out_of_scope.yaml
@@ -0,0 +1,13 @@
+# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+
+# CHECK: trying to export function 42 which is undefined in this scope
+
+--- !WASM
+FileHeader:
+ Version: 0x00000001
+Sections:
+ - Type: EXPORT
+ Exports:
+ - Name: function_export
+ Kind: FUNCTION
+ Index: 42
diff --git a/mlir/test/Target/Wasm/global.mlir b/mlir/test/Target/Wasm/global.mlir
new file mode 100644
index 0000000..e72fe69
--- /dev/null
+++ b/mlir/test/Target/Wasm/global.mlir
@@ -0,0 +1,66 @@
+// RUN: yaml2obj %S/inputs/global.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+
+;; import a global variable from js
+(global $imported_glob (import "env" "from_js") i32)
+
+;; create a global variable
+(global $normal_glob i32(i32.const 10))
+(global $glob_mut (mut i32) (i32.const 10))
+(global $glob_mut_ext (mut i32) (i32.const 10))
+
+(global $normal_glob_i64 i64(i64.const 11))
+(global $normal_glob_f32 f32(f32.const 12))
+(global $normal_glob_f64 f64(f64.const 13))
+
+(func $main (result i32)
+;; load both global variables onto the stack
+global.get $imported_glob
+global.get $normal_glob
+
+i32.add ;; add up both globals
+
+global.get $glob_mut
+global.get $glob_mut_ext
+i32.add
+i32.add
+)
+)
+*/
+
+// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+
+// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.global_get @global_0 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.global_get @global_1 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.global_get @global_2 : i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.global_get @global_3 : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_4]] : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_5]] : i32
+// CHECK: wasmssa.return %[[VAL_6]] : i32
+
+// CHECK-LABEL: wasmssa.global @global_1 i32 nested : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+
+// CHECK-LABEL: wasmssa.global @global_2 i32 mutable nested : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+
+// CHECK-LABEL: wasmssa.global @global_3 i32 mutable nested : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+
+// CHECK-LABEL: wasmssa.global @global_4 i64 nested : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 11 : i64
+// CHECK: wasmssa.return %[[VAL_0]] : i64
+
+// CHECK-LABEL: wasmssa.global @global_5 f32 nested : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.200000e+01 : f32
+// CHECK: wasmssa.return %[[VAL_0]] : f32
+
+// CHECK-LABEL: wasmssa.global @global_6 f64 nested : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.300000e+01 : f64
+// CHECK: wasmssa.return %[[VAL_0]] : f64
diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir
new file mode 100644
index 0000000..541dcf3
--- /dev/null
+++ b/mlir/test/Target/Wasm/import.mlir
@@ -0,0 +1,19 @@
+// RUN: yaml2obj %S/inputs/import.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(import "my_module" "foo" (func $foo (param i32)))
+(import "my_module" "bar" (func $bar (param i32)))
+(import "my_module" "table" (table $round 2 funcref))
+(import "my_module" "mem" (memory $mymem 2))
+(import "my_module" "glob" (global $globglob i32))
+(import "my_other_module" "glob_mut" (global $glob_mut (mut i32)))
+)
+*/
+
+// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
+// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
+// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
+// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
diff --git a/mlir/test/Target/Wasm/inputs/abs.yaml.wasm b/mlir/test/Target/Wasm/inputs/abs.yaml.wasm
new file mode 100644
index 0000000..1cb6d21
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/abs.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: abs_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: abs_f64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 43000020418B0B
+ - Index: 1
+ Locals: []
+ Body: 440000000000002440990B
+...
diff --git a/mlir/test/Target/Wasm/inputs/and.yaml.wasm b/mlir/test/Target/Wasm/inputs/and.yaml.wasm
new file mode 100644
index 0000000..926445b
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/and.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: and_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: and_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103710B
+ - Index: 1
+ Locals: []
+ Body: 420A4203830B
+...
diff --git a/mlir/test/Target/Wasm/inputs/clz.yaml.wasm b/mlir/test/Target/Wasm/inputs/clz.yaml.wasm
new file mode 100644
index 0000000..f537bdb
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/clz.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: clz_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: clz_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A670B
+ - Index: 1
+ Locals: []
+ Body: 420A790B
+...
diff --git a/mlir/test/Target/Wasm/inputs/const.yaml.wasm b/mlir/test/Target/Wasm/inputs/const.yaml.wasm
new file mode 100644
index 0000000..be8f88e
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/const.yaml.wasm
@@ -0,0 +1,39 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 2
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 3
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 2, 3 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 41010B
+ - Index: 1
+ Locals: []
+ Body: 42030B
+ - Index: 2
+ Locals: []
+ Body: 43000080400B
+ - Index: 3
+ Locals: []
+ Body: 4400000000000022400B
+...
diff --git a/mlir/test/Target/Wasm/inputs/copysign.yaml.wasm b/mlir/test/Target/Wasm/inputs/copysign.yaml.wasm
new file mode 100644
index 0000000..46c0412
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/copysign.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: copysign_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: copysign_f64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4300002041430000803F980B
+ - Index: 1
+ Locals: []
+ Body: 44000000000000244044000000000000F03FA60B
+...
diff --git a/mlir/test/Target/Wasm/inputs/ctz.yaml.wasm b/mlir/test/Target/Wasm/inputs/ctz.yaml.wasm
new file mode 100644
index 0000000..5140085
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/ctz.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: ctz_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: ctz_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A680B
+ - Index: 1
+ Locals: []
+ Body: 420A7A0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/div.yaml.wasm b/mlir/test/Target/Wasm/inputs/div.yaml.wasm
new file mode 100644
index 0000000..648e10c
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/div.yaml.wasm
@@ -0,0 +1,89 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 2
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 3
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0, 1, 1, 1, 1, 2, 3 ]
+ - Type: EXPORT
+ Exports:
+ - Name: div_u_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: div_u_i32_zero
+ Kind: FUNCTION
+ Index: 1
+ - Name: div_s_i32
+ Kind: FUNCTION
+ Index: 2
+ - Name: div_s_i32_zero
+ Kind: FUNCTION
+ Index: 3
+ - Name: div_u_i64
+ Kind: FUNCTION
+ Index: 4
+ - Name: div_u_i64_zero
+ Kind: FUNCTION
+ Index: 5
+ - Name: div_s_i64
+ Kind: FUNCTION
+ Index: 6
+ - Name: div_s_i64_zero
+ Kind: FUNCTION
+ Index: 7
+ - Name: div_f32
+ Kind: FUNCTION
+ Index: 8
+ - Name: div_f64
+ Kind: FUNCTION
+ Index: 9
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A41026E0B
+ - Index: 1
+ Locals: []
+ Body: 410A41006E0B
+ - Index: 2
+ Locals: []
+ Body: 410A41026D0B
+ - Index: 3
+ Locals: []
+ Body: 410A41006D0B
+ - Index: 4
+ Locals: []
+ Body: 420A4202800B
+ - Index: 5
+ Locals: []
+ Body: 420A4200800B
+ - Index: 6
+ Locals: []
+ Body: 420A42027F0B
+ - Index: 7
+ Locals: []
+ Body: 420A42007F0B
+ - Index: 8
+ Locals: []
+ Body: 43000020414300000040950B
+ - Index: 9
+ Locals: []
+ Body: 440000000000002440440000000000000040A30B
+...
diff --git a/mlir/test/Target/Wasm/inputs/global.yaml.wasm b/mlir/test/Target/Wasm/inputs/global.yaml.wasm
new file mode 100644
index 0000000..4bbb434
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/global.yaml.wasm
@@ -0,0 +1,63 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: IMPORT
+ Imports:
+ - Module: env
+ Field: from_js
+ Kind: GLOBAL
+ GlobalType: I32
+ GlobalMutable: false
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: GLOBAL
+ Globals:
+ - Index: 1
+ Type: I32
+ Mutable: false
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 10
+ - Index: 2
+ Type: I32
+ Mutable: true
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 10
+ - Index: 3
+ Type: I32
+ Mutable: true
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 10
+ - Index: 4
+ Type: I64
+ Mutable: false
+ InitExpr:
+ Opcode: I64_CONST
+ Value: 11
+ - Index: 5
+ Type: F32
+ Mutable: false
+ InitExpr:
+ Opcode: F32_CONST
+ Value: 1094713344
+ - Index: 6
+ Type: F64
+ Mutable: false
+ InitExpr:
+ Opcode: F64_CONST
+ Value: 4623507967449235456
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 230023016A230223036A6A0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/import.yaml.wasm b/mlir/test/Target/Wasm/inputs/import.yaml.wasm
new file mode 100644
index 0000000..7c467ff
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/import.yaml.wasm
@@ -0,0 +1,44 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes: []
+ - Type: IMPORT
+ Imports:
+ - Module: my_module
+ Field: foo
+ Kind: FUNCTION
+ SigIndex: 0
+ - Module: my_module
+ Field: bar
+ Kind: FUNCTION
+ SigIndex: 0
+ - Module: my_module
+ Field: table
+ Kind: TABLE
+ Table:
+ Index: 0
+ ElemType: FUNCREF
+ Limits:
+ Minimum: 0x2
+ - Module: my_module
+ Field: mem
+ Kind: MEMORY
+ Memory:
+ Minimum: 0x2
+ - Module: my_module
+ Field: glob
+ Kind: GLOBAL
+ GlobalType: I32
+ GlobalMutable: false
+ - Module: my_other_module
+ Field: glob_mut
+ Kind: GLOBAL
+ GlobalType: I32
+ GlobalMutable: true
+...
diff --git a/mlir/test/Target/Wasm/inputs/local.yaml.wasm b/mlir/test/Target/Wasm/inputs/local.yaml.wasm
new file mode 100644
index 0000000..3e937f3
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/local.yaml.wasm
@@ -0,0 +1,37 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 2
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 2 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals:
+ - Type: F32
+ Count: 2
+ Body: 43000000412100200043000040412201920B
+ - Index: 1
+ Locals:
+ - Type: I32
+ Count: 2
+ Body: 410821002000410C22016A0B
+ - Index: 2
+ Locals: []
+ Body: 4103210020000B
+...
diff --git a/mlir/test/Target/Wasm/inputs/max.yaml.wasm b/mlir/test/Target/Wasm/inputs/max.yaml.wasm
new file mode 100644
index 0000000..fc04b01
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/max.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: min_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: min_f64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4300002041430000803F970B
+ - Index: 1
+ Locals: []
+ Body: 44000000000000244044000000000000F03FA50B
+...
diff --git a/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm
new file mode 100644
index 0000000..f3edf5f
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/memory_min_eq_max.yaml.wasm
@@ -0,0 +1,10 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: MEMORY
+ Memories:
+ - Flags: [ HAS_MAX ]
+ Minimum: 0x0
+ Maximum: 0x0
+...
diff --git a/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm
new file mode 100644
index 0000000..fe70fb6
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/memory_min_max.yaml.wasm
@@ -0,0 +1,10 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: MEMORY
+ Memories:
+ - Flags: [ HAS_MAX ]
+ Minimum: 0x0
+ Maximum: 0x10000
+...
diff --git a/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm
new file mode 100644
index 0000000..8508ce3
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/memory_min_no_max.yaml.wasm
@@ -0,0 +1,8 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: MEMORY
+ Memories:
+ - Minimum: 0x1
+...
diff --git a/mlir/test/Target/Wasm/inputs/min.yaml.wasm b/mlir/test/Target/Wasm/inputs/min.yaml.wasm
new file mode 100644
index 0000000..925a5e9
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/min.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: min_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: min_f64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4300002041430000803F960B
+ - Index: 1
+ Locals: []
+ Body: 44000000000000244044000000000000F03FA40B
+...
diff --git a/mlir/test/Target/Wasm/inputs/neg.yaml.wasm b/mlir/test/Target/Wasm/inputs/neg.yaml.wasm
new file mode 100644
index 0000000..8392a2b
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/neg.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: neg_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: neg_f64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 43000020418C0B
+ - Index: 1
+ Locals: []
+ Body: 4400000000000024409A0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/or.yaml.wasm b/mlir/test/Target/Wasm/inputs/or.yaml.wasm
new file mode 100644
index 0000000..9528ce8
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/or.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: or_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: or_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103720B
+ - Index: 1
+ Locals: []
+ Body: 420A4203840B
+...
diff --git a/mlir/test/Target/Wasm/inputs/popcnt.yaml.wasm b/mlir/test/Target/Wasm/inputs/popcnt.yaml.wasm
new file mode 100644
index 0000000..03c57ad
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/popcnt.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: popcnt_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: popcnt_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A690B
+ - Index: 1
+ Locals: []
+ Body: 420A7B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/rem.yaml.wasm b/mlir/test/Target/Wasm/inputs/rem.yaml.wasm
new file mode 100644
index 0000000..468a9db
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/rem.yaml.wasm
@@ -0,0 +1,45 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: rem_u_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: rem_u_i64
+ Kind: FUNCTION
+ Index: 1
+ - Name: rem_s_i32
+ Kind: FUNCTION
+ Index: 2
+ - Name: rem_s_i64
+ Kind: FUNCTION
+ Index: 3
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103700B
+ - Index: 1
+ Locals: []
+ Body: 420A4203820B
+ - Index: 2
+ Locals: []
+ Body: 410A41036F0B
+ - Index: 3
+ Locals: []
+ Body: 420A4203810B
+...
diff --git a/mlir/test/Target/Wasm/inputs/rotl.yaml.wasm b/mlir/test/Target/Wasm/inputs/rotl.yaml.wasm
new file mode 100644
index 0000000..87466cb1
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/rotl.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: rotl_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: rotl_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103770B
+ - Index: 1
+ Locals: []
+ Body: 420A4203890B
+...
diff --git a/mlir/test/Target/Wasm/inputs/rotr.yaml.wasm b/mlir/test/Target/Wasm/inputs/rotr.yaml.wasm
new file mode 100644
index 0000000..805a93f
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/rotr.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: rotr_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: rotr_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103780B
+ - Index: 1
+ Locals: []
+ Body: 420A42038A0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/shl.yaml.wasm b/mlir/test/Target/Wasm/inputs/shl.yaml.wasm
new file mode 100644
index 0000000..d07605e
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/shl.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: shl_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: shl_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103740B
+ - Index: 1
+ Locals: []
+ Body: 420A4203860B
+...
diff --git a/mlir/test/Target/Wasm/inputs/shr_s.yaml.wasm b/mlir/test/Target/Wasm/inputs/shr_s.yaml.wasm
new file mode 100644
index 0000000..d5d8013
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/shr_s.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: shr_s_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: shr_s_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103750B
+ - Index: 1
+ Locals: []
+ Body: 420A4203870B
+...
diff --git a/mlir/test/Target/Wasm/inputs/shr_u.yaml.wasm b/mlir/test/Target/Wasm/inputs/shr_u.yaml.wasm
new file mode 100644
index 0000000..cd81514
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/shr_u.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: shr_u_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: shr_u_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103760B
+ - Index: 1
+ Locals: []
+ Body: 420A4203880B
+...
diff --git a/mlir/test/Target/Wasm/inputs/sqrt.yaml.wasm b/mlir/test/Target/Wasm/inputs/sqrt.yaml.wasm
new file mode 100644
index 0000000..f8ab84b
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/sqrt.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: sqrt_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: sqrt_f64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4300002041910B
+ - Index: 1
+ Locals: []
+ Body: 4400000000000024409F0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/stats.yaml.wasm b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm
new file mode 100644
index 0000000..bf57768
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/stats.yaml.wasm
@@ -0,0 +1,38 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: TABLE
+ Tables:
+ - Index: 0
+ ElemType: FUNCREF
+ Limits:
+ Minimum: 0x2
+ - Type: MEMORY
+ Memories:
+ - Flags: [ HAS_MAX ]
+ Minimum: 0x0
+ Maximum: 0x10000
+ - Type: GLOBAL
+ Globals:
+ - Index: 0
+ Type: I32
+ Mutable: false
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 10
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 20000B
+...
diff --git a/mlir/test/Target/Wasm/inputs/sub.yaml.wasm b/mlir/test/Target/Wasm/inputs/sub.yaml.wasm
new file mode 100644
index 0000000..95b6bcc
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/sub.yaml.wasm
@@ -0,0 +1,39 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 2
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 3
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 2, 3 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C41326B0B
+ - Index: 1
+ Locals: []
+ Body: 421442057D0B
+ - Index: 2
+ Locals: []
+ Body: 430000A0404300006041930B
+ - Index: 3
+ Locals: []
+ Body: 440000000000003140440000000000000000A10B
+...
diff --git a/mlir/test/Target/Wasm/inputs/table.yaml.wasm b/mlir/test/Target/Wasm/inputs/table.yaml.wasm
new file mode 100644
index 0000000..387f418
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/table.yaml.wasm
@@ -0,0 +1,23 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TABLE
+ Tables:
+ - Index: 0
+ ElemType: FUNCREF
+ Limits:
+ Minimum: 0x2
+ - Index: 1
+ ElemType: FUNCREF
+ Limits:
+ Flags: [ HAS_MAX ]
+ Minimum: 0x2
+ Maximum: 0x4
+ - Index: 2
+ ElemType: EXTERNREF
+ Limits:
+ Flags: [ HAS_MAX ]
+ Minimum: 0x2
+ Maximum: 0x4
+...
diff --git a/mlir/test/Target/Wasm/inputs/xor.yaml.wasm b/mlir/test/Target/Wasm/inputs/xor.yaml.wasm
new file mode 100644
index 0000000..45079c3
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/xor.yaml.wasm
@@ -0,0 +1,33 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: xor_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: xor_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410A4103730B
+ - Index: 1
+ Locals: []
+ Body: 420A4203850B
+...
diff --git a/mlir/test/Target/Wasm/invalid_function_type_index.yaml b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
new file mode 100644
index 0000000..2d2954a
--- /dev/null
+++ b/mlir/test/Target/Wasm/invalid_function_type_index.yaml
@@ -0,0 +1,16 @@
+# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+# CHECK: error: invalid type index: 2
+
+--- !WASM
+FileHeader:
+ Version: 0x00000001
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes:
+ - 2
diff --git a/mlir/test/Target/Wasm/local.mlir b/mlir/test/Target/Wasm/local.mlir
new file mode 100644
index 0000000..32f5900
--- /dev/null
+++ b/mlir/test/Target/Wasm/local.mlir
@@ -0,0 +1,59 @@
+// RUN: yaml2obj %S/inputs/local.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $local_f32 (result f32)
+ (local $var1 f32)
+ (local $var2 f32)
+ f32.const 8.0
+ local.set $var1
+ local.get $var1
+ f32.const 12.0
+ local.tee $var2
+ f32.add
+ )
+ (func $local_i32 (result i32)
+ (local $var1 i32)
+ (local $var2 i32)
+ i32.const 8
+ local.set $var1
+ local.get $var1
+ i32.const 12
+ local.tee $var2
+ i32.add
+ )
+ (func $local_arg (param $var i32) (result i32)
+ i32.const 3
+ local.set $var
+ local.get $var
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32
+// CHECK: wasmssa.local_set %[[VAL_0]] : ref to f32 to %[[VAL_2]] : f32
+// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[VAL_0]] : ref to f32
+// CHECK: %[[VAL_4:.*]] = wasmssa.const 1.200000e+01 : f32
+// CHECK: %[[VAL_5:.*]] = wasmssa.local_tee %[[VAL_1]] : ref to f32 to %[[VAL_4]] : f32
+// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : f32
+// CHECK: wasmssa.return %[[VAL_6]] : f32
+
+// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32
+// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_2]] : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.local_tee %[[VAL_1]] : ref to i32 to %[[VAL_4]] : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : i32
+// CHECK: wasmssa.return %[[VAL_6]] : i32
+
+// CHECK-LABEL: wasmssa.func nested @func_2(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32
+// CHECK: wasmssa.local_set %[[ARG0]] : ref to i32 to %[[VAL_0]] : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/max.mlir b/mlir/test/Target/Wasm/max.mlir
new file mode 100644
index 0000000..4ef2042
--- /dev/null
+++ b/mlir/test/Target/Wasm/max.mlir
@@ -0,0 +1,30 @@
+// RUN: yaml2obj %S/inputs/max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "max_f32") (result f32)
+ f32.const 10
+ f32.const 1
+ f32.max
+ )
+
+ (func (export "max_f64") (result f64)
+ f64.const 10
+ f64.const 1
+ f64.max
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @min_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f32
+// CHECK: wasmssa.return %[[VAL_2]] : f32
+
+
+// CHECK-LABEL: wasmssa.func @min_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f64
+// CHECK: wasmssa.return %[[VAL_2]] : f64
diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
new file mode 100644
index 0000000..2ba5ab5
--- /dev/null
+++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
@@ -0,0 +1,7 @@
+// RUN: yaml2obj %S/inputs/memory_min_eq_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module (memory 0 0))
+*/
+
+// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 0]>
diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir
new file mode 100644
index 0000000..ebf6418
--- /dev/null
+++ b/mlir/test/Target/Wasm/memory_min_max.mlir
@@ -0,0 +1,7 @@
+// RUN: yaml2obj %S/inputs/memory_min_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module (memory 0 65536))
+*/
+
+// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 65536]>
diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir
new file mode 100644
index 0000000..8d88786
--- /dev/null
+++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir
@@ -0,0 +1,7 @@
+// RUN: yaml2obj %S/inputs/memory_min_no_max.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module (memory 1))
+*/
+
+// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[1:]>
diff --git a/mlir/test/Target/Wasm/min.mlir b/mlir/test/Target/Wasm/min.mlir
new file mode 100644
index 0000000..1058c7d
--- /dev/null
+++ b/mlir/test/Target/Wasm/min.mlir
@@ -0,0 +1,29 @@
+// RUN: yaml2obj %S/inputs/min.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "min_f32") (result f32)
+ f32.const 10
+ f32.const 1
+ f32.min
+ )
+
+ (func (export "min_f64") (result f64)
+ f64.const 10
+ f64.const 1
+ f64.min
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @min_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f32
+// CHECK: wasmssa.return %[[VAL_2]] : f32
+
+// CHECK-LABEL: wasmssa.func @min_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f64
+// CHECK: wasmssa.return %[[VAL_2]] : f64
diff --git a/mlir/test/Target/Wasm/missing_header.yaml b/mlir/test/Target/Wasm/missing_header.yaml
new file mode 100644
index 0000000..a9f812e
--- /dev/null
+++ b/mlir/test/Target/Wasm/missing_header.yaml
@@ -0,0 +1,12 @@
+# RUN: not yaml2obj %s -o - | not mlir-translate --import-wasm 2>&1 | FileCheck %s
+
+# CHECK: source file does not contain valid Wasm header
+
+--- !WASM
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+...
diff --git a/mlir/test/Target/Wasm/neg.mlir b/mlir/test/Target/Wasm/neg.mlir
new file mode 100644
index 0000000..5811ab50
--- /dev/null
+++ b/mlir/test/Target/Wasm/neg.mlir
@@ -0,0 +1,23 @@
+// RUN: yaml2obj %S/inputs/neg.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "neg_f32") (result f32)
+ f32.const 10
+ f32.neg)
+
+ (func (export "neg_f64") (result f64)
+ f64.const 10
+ f64.neg)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @neg_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @neg_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/or.mlir b/mlir/test/Target/Wasm/or.mlir
new file mode 100644
index 0000000..521f2ba
--- /dev/null
+++ b/mlir/test/Target/Wasm/or.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/or.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "or_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.or)
+
+ (func (export "or_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.or)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @or_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.or %0 %1 : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @or_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.or %0 %1 : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/popcnt.mlir b/mlir/test/Target/Wasm/popcnt.mlir
new file mode 100644
index 0000000..235333a
--- /dev/null
+++ b/mlir/test/Target/Wasm/popcnt.mlir
@@ -0,0 +1,25 @@
+// RUN: yaml2obj %S/inputs/popcnt.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "popcnt_i32") (result i32)
+ i32.const 10
+ i32.popcnt
+ )
+
+ (func (export "popcnt_i64") (result i64)
+ i64.const 10
+ i64.popcnt
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @popcnt_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @popcnt_i64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/rem.mlir b/mlir/test/Target/Wasm/rem.mlir
new file mode 100644
index 0000000..b19b8d9
--- /dev/null
+++ b/mlir/test/Target/Wasm/rem.mlir
@@ -0,0 +1,53 @@
+// RUN: yaml2obj %S/inputs/rem.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "rem_u_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.rem_u)
+
+ (func (export "rem_u_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.rem_u)
+
+ (func (export "rem_s_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.rem_s)
+
+ (func (export "rem_s_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.rem_s)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @rem_u_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.rem_ui %0 %1 : i32
+// CHECK: wasmssa.return %2 : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @rem_u_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.rem_ui %0 %1 : i64
+// CHECK: wasmssa.return %2 : i64
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @rem_s_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.rem_si %0 %1 : i32
+// CHECK: wasmssa.return %2 : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @rem_s_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.rem_si %0 %1 : i64
+// CHECK: wasmssa.return %2 : i64
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/rotl.mlir b/mlir/test/Target/Wasm/rotl.mlir
new file mode 100644
index 0000000..ec573554
--- /dev/null
+++ b/mlir/test/Target/Wasm/rotl.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/rotl.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "rotl_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.rotl)
+
+ (func (export "rotl_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.rotl)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @rotl_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.rotl %0 by %1 bits : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @rotl_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.rotl %0 by %1 bits : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/rotr.mlir b/mlir/test/Target/Wasm/rotr.mlir
new file mode 100644
index 0000000..5618b43
--- /dev/null
+++ b/mlir/test/Target/Wasm/rotr.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/rotr.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "rotr_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.rotr)
+
+ (func (export "rotr_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.rotr)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @rotr_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.rotr %0 by %1 bits : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @rotr_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.rotr %0 by %1 bits : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/shl.mlir b/mlir/test/Target/Wasm/shl.mlir
new file mode 100644
index 0000000..f2bdd57
--- /dev/null
+++ b/mlir/test/Target/Wasm/shl.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/shl.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "shl_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.shl)
+
+ (func (export "shl_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.shl)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @shl_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.shl %0 by %1 bits : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @shl_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.shl %0 by %1 bits : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/shr_s.mlir b/mlir/test/Target/Wasm/shr_s.mlir
new file mode 100644
index 0000000..247d9be
--- /dev/null
+++ b/mlir/test/Target/Wasm/shr_s.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/shr_s.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "shr_s_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.shr_s)
+
+ (func (export "shr_s_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.shr_s)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @shr_s_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @shr_s_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/shr_u.mlir b/mlir/test/Target/Wasm/shr_u.mlir
new file mode 100644
index 0000000..9a79eed
--- /dev/null
+++ b/mlir/test/Target/Wasm/shr_u.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/shr_u.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "shr_u_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.shr_u)
+
+ (func (export "shr_u_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.shr_u)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @shr_u_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @shr_u_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Target/Wasm/sqrt.mlir b/mlir/test/Target/Wasm/sqrt.mlir
new file mode 100644
index 0000000..77444ad
--- /dev/null
+++ b/mlir/test/Target/Wasm/sqrt.mlir
@@ -0,0 +1,23 @@
+// RUN: yaml2obj %S/inputs/sqrt.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "sqrt_f32") (result f32)
+ f32.const 10
+ f32.sqrt)
+
+ (func (export "sqrt_f64") (result f64)
+ f64.const 10
+ f64.sqrt)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @sqrt_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @sqrt_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/sub.mlir b/mlir/test/Target/Wasm/sub.mlir
new file mode 100644
index 0000000..b9c6caf
--- /dev/null
+++ b/mlir/test/Target/Wasm/sub.mlir
@@ -0,0 +1,52 @@
+// RUN: yaml2obj %S/inputs/sub.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $sub_i32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.sub
+ )
+
+ (func $sub_i64 (result i64)
+ i64.const 20
+ i64.const 5
+ i64.sub
+ )
+
+ (func $sub_f32 (result f32)
+ f32.const 5
+ f32.const 14
+ f32.sub
+ )
+
+ (func $sub_f64 (result f64)
+ f64.const 17
+ f64.const 0
+ f64.sub
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i64
+// CHECK: wasmssa.return %[[VAL_2]] : i64
+
+// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f32
+// CHECK: wasmssa.return %[[VAL_2]] : f32
+
+// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f64
+// CHECK: wasmssa.return %[[VAL_2]] : f64
diff --git a/mlir/test/Target/Wasm/xor.mlir b/mlir/test/Target/Wasm/xor.mlir
new file mode 100644
index 0000000..94691de
--- /dev/null
+++ b/mlir/test/Target/Wasm/xor.mlir
@@ -0,0 +1,27 @@
+// RUN: yaml2obj %S/inputs/xor.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "xor_i32") (result i32)
+ i32.const 10
+ i32.const 3
+ i32.xor)
+
+ (func (export "xor_i64") (result i64)
+ i64.const 10
+ i64.const 3
+ i64.xor)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @xor_i32() -> i32 {
+// CHECK: %0 = wasmssa.const 10 : i32
+// CHECK: %1 = wasmssa.const 3 : i32
+// CHECK: %2 = wasmssa.xor %0 %1 : i32
+// CHECK: wasmssa.return %2 : i32
+
+// CHECK-LABEL: wasmssa.func @xor_i64() -> i64 {
+// CHECK: %0 = wasmssa.const 10 : i64
+// CHECK: %1 = wasmssa.const 3 : i64
+// CHECK: %2 = wasmssa.xor %0 %1 : i64
+// CHECK: wasmssa.return %2 : i64
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index 1ed0887..d8e10aa 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -5,14 +5,18 @@
// RUN: mlir-opt %s -inline='op-pipelines=func.func(canonicalize,cse)' | FileCheck %s --check-prefix INLINE_SIMPLIFY
// Inline a function that takes an argument.
-func.func @func_with_arg(%c : i32) -> i32 {
- %b = arith.addi %c, %c : i32
- return %b : i32
+func.func @func_with_arg(%arg0 : i32) -> i32 {
+ %b = arith.addi %arg0, %arg0 : i32
+ %c = builtin.unrealized_conversion_cast %b : i32 to i64
+ %d = builtin.unrealized_conversion_cast %c : i64 to i32
+ return %d : i32
}
// CHECK-LABEL: func @inline_with_arg
func.func @inline_with_arg(%arg0 : i32) -> i32 {
// CHECK-NEXT: arith.addi
+ // CHECK-NEXT: unrealized_conversion_cast
+ // CHECK-NEXT: unrealized_conversion_cast
// CHECK-NEXT: return
%0 = call @func_with_arg(%arg0) : (i32) -> i32
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9ded6a3..fa2c145 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -571,3 +571,47 @@ module @return_void_with_unused_argument {
}
}
+// -----
+
+// CHECK-LABEL: module @dynamically_unreachable
+module @dynamically_unreachable {
+ func.func @dynamically_unreachable() {
+ // This value is used by an operation in a dynamically unreachable block.
+ %zero = arith.constant 0 : i64
+
+ // Dataflow analysis knows from the constant condition that
+ // ^bb1 is unreachable
+ %false = arith.constant false
+ cf.cond_br %false, ^bb1, ^bb4
+ ^bb1:
+ // This unreachable operation should be removed.
+ // CHECK-NOT: arith.cmpi
+ %3 = arith.cmpi eq, %zero, %zero : i64
+ cf.br ^bb1
+ ^bb4:
+ return
+ }
+}
+
+// CHECK-LABEL: module @last_block_not_exit
+module @last_block_not_exit {
+ // return value can be removed because it's private.
+ func.func private @terminated_with_condbr(%arg0: i1, %arg1: i1) -> i1 {
+ %true = arith.constant true
+ %false = arith.constant false
+ cf.cond_br %arg0, ^bb1(%false : i1), ^bb2
+ ^bb1(%1: i1): // 2 preds: ^bb0, ^bb2
+ return %1 : i1
+ ^bb2: // pred: ^bb3
+ cf.cond_br %arg1, ^bb1(%false : i1), ^bb1(%true : i1)
+ }
+
+ func.func public @call_private_but_not_use() {
+ %i0 = arith.constant 0: i1
+ %i1 = arith.constant 1: i1
+ call @terminated_with_condbr(%i0, %i1) : (i1, i1) -> i1
+ func.return
+ }
+ // CHECK-LABEL: @call_private_but_not_use
+ // CHECK: call @terminated_with_condbr(%false, %true) : (i1, i1)
+}
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 0fc822b..8cad6b9 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s --check-prefixes=CHECK,RS
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
@@ -80,12 +80,10 @@ func.func @test_dialect_canonicalizer() -> (i32) {
// Check that the option to control region simplification actually works
// CHECK-LABEL: test_region_simplify
-func.func @test_region_simplify() {
- // CHECK-NEXT: return
- // NO-RS-NEXT: ^bb1
- // NO-RS-NEXT: return
- // CHECK-NEXT: }
- return
-^bb1:
- return
+func.func @test_region_simplify(%input1 : i32, %cond : i1) -> i32 {
+ // RS-NEXT: "test.br"(%arg0)[^bb1] : (i32) -> ()
+ // NO-RS-NEXT: "test.br"(%arg0, %arg0)[^bb1] : (i32, i32) -> ()
+ "test.br"(%input1, %input1)[^bb1] : (i32, i32) -> ()
+^bb1(%used_arg : i32, %unused_arg : i32):
+ return %used_arg : i32
}
diff --git a/mlir/test/Transforms/test-context-aware-type-converter.mlir b/mlir/test/Transforms/test-context-aware-type-converter.mlir
new file mode 100644
index 0000000..ae178b6
--- /dev/null
+++ b/mlir/test/Transforms/test-context-aware-type-converter.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-legalize-type-conversion="allow-pattern-rollback=0" -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @simple_context_aware_conversion_1()
+func.func @simple_context_aware_conversion_1() attributes {increment = 1 : i64} {
+ // Case 1: Convert i37 --> i38.
+ // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
+ // CHECK: "test.legal_op_d"(%[[cast]]) : (i38) -> ()
+ %0 = "test.context_op"() : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+ return
+}
+
+// CHECK-LABEL: func @simple_context_aware_conversion_2()
+func.func @simple_context_aware_conversion_2() attributes {increment = 2 : i64} {
+ // Case 2: Convert i37 --> i39.
+ // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i39
+ // CHECK: "test.legal_op_d"(%[[cast]]) : (i39) -> ()
+ %0 = "test.context_op"() : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+ return
+}
+
+// -----
+
+// Note: This test case does not work with allow-pattern-rollback=1. When
+// rollback is enabled, the type converter cannot find the enclosing function
+// because the operand of the scf.yield is pointing to a detached block.
+
+// CHECK-LABEL: func @convert_block_arguments
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[iter:.*]] = %[[cast]]) -> (i38) {
+// CHECK: scf.yield %[[iter]] : i38
+// CHECK: }
+func.func @convert_block_arguments(%lb: index, %ub: index, %step: index) attributes {increment = 1 : i64} {
+ %0 = "test.context_op"() : () -> (i37)
+ scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %0) -> i37 {
+ scf.yield %arg0 : i37
+ }
+ return
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 9bffe92..c003f8b 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,4 @@ func.func @test_signature_conversion_no_converter() {
}) : () -> ()
return
}
+
diff --git a/mlir/test/Transforms/test-legalizer-fold-after.mlir b/mlir/test/Transforms/test-legalizer-fold-after.mlir
new file mode 100644
index 0000000..7f80252
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-after.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK-NOT: op_in_place_self_fold
+ // CHECK: 97
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-fold-before.mlir b/mlir/test/Transforms/test-legalizer-fold-before.mlir
new file mode 100644
index 0000000..fe6e293
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-fold-before.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+ // CHECK: op_in_place_self_fold
+ // CHECK-SAME: folded
+ %1 = "test.op_in_place_self_fold"() : () -> (i32)
+ "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000..720d17f
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+ // Check that op was not folded.
+ // CHECK: "test.op_with_region_fold"
+ %0 = "test.op_with_region_fold"(%arg0) ({
+ "foo.op_with_region_terminator"() : () -> ()
+ }) : (i32) -> (i32)
+ "test.return"(%0) : (i32) -> ()
+}
+
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 5630d15..3fa42ff 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,9 +1,15 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
+
+// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
+// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
+// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "B"
+// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "E"
+// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "E"
+// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "E"
-// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "B"
-// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "B"
-// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "B"
-// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "E"
// Note: Listener notifications appear after the pattern application because
// the conversion driver sends all notifications at the end of the conversion
// in bulk.
@@ -11,8 +17,6 @@
// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
// CHECK-NEXT: notifyOperationModified: func.return
// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
-// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "E"
-// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "E"
// CHECK-LABEL: verifyDirectPattern
func.func @verifyDirectPattern() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
@@ -29,7 +33,9 @@ func.func @verifyDirectPattern() -> i32 {
// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
-// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+// Note: func.return is modified a second time when running in no-rollback
+// mode.
+// CHECK: notifyOperationErased: test.illegal_op_e
// CHECK-LABEL: verifyLargerBenefit
func.func @verifyLargerBenefit() -> i32 {
@@ -70,7 +76,7 @@ func.func @remap_call_1_to_1(%arg0: i64) {
// CHECK: notifyBlockInserted into func.func: was unlinked
// Contents of the old block are moved to the new block.
-// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+// CHECK-NEXT: notifyOperationInserted: test.return
// The old block is erased.
// CHECK-NEXT: notifyBlockErased
@@ -185,9 +191,12 @@ func.func @remap_drop_region() {
// -----
// CHECK-LABEL: func @dropped_input_in_use
+// CHECK-KIND-LABEL: func @dropped_input_in_use
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
- // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
- // CHECK-NEXT: "work"{{.*}} : (i16)
+ // CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
+ // CHECK-NEXT: "work"(%[[cast]]) : (i16)
+ // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
+ // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg) : (i16) -> ()
}
@@ -409,8 +418,10 @@ func.func @test_remap_block_arg() {
// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
-// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16
-// CHECK: "test.valid"(%[[cast]]) : (f16) -> ()
+// Note: There is a bug in the rollback-based conversion driver: it emits a
+// "test.cast" : (f16, f16, f16, f16) -> f16, when it should be emitting
+// three consecutive casts of (f16, f16) -> f16.
+// CHECK: "test.valid"(%{{.*}}) : (f16) -> ()
func.func @test_multiple_1_to_n_replacement() {
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
"test.invalid"(%0) : (f16) -> ()
@@ -423,6 +434,11 @@ func.func @test_multiple_1_to_n_replacement() {
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+// CHECK-KIND-LABEL: func @test_lookup_without_converter
+// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
+// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
+// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
func.func @test_lookup_without_converter() {
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
@@ -432,3 +448,24 @@ func.func @test_lookup_without_converter() {
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}
+
+// -----
+// expected-remark@-1 {{applyPartialConversion failed}}
+
+func.func @test_skip_1to1_pattern(%arg0: f32) {
+ // expected-error@+1 {{failed to legalize operation 'test.type_consumer'}}
+ "test.type_consumer"(%arg0) : (f32) -> ()
+ return
+}
+
+// -----
+
+// Demonstrate that the pattern generally works, but only for 1:1 type
+// conversions.
+
+// CHECK-LABEL: @test_working_1to1_pattern(
+func.func @test_working_1to1_pattern(%arg0: f16) {
+ // CHECK-NEXT: "test.return"() : () -> ()
+ "test.type_consumer"(%arg0) : (f16) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
index 43005e2..8e2f03b 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
@@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass
void runOnOperation() override {
auto &livenessAnalysis = getAnalysis<RunLivenessAnalysis>();
-
Operation *op = getOperation();
raw_ostream &os = llvm::outs();
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
index e708b7d..8bab9a0 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
@@ -42,6 +42,12 @@ private:
ListOption<unsigned> permList{*this, "permutation-map",
llvm::cl::desc("Specify the loop permutation"),
llvm::cl::OneOrMore};
+
+ /// Specify whether to check validity of loop permutation.
+ Option<bool> checkValidity{
+ *this, "check-validity",
+ llvm::cl::desc("Check validity of the loop permutation"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -60,6 +66,9 @@ void TestLoopPermutation::runOnOperation() {
// Permute if the nest's size is consistent with the specified
// permutation.
if (nest.size() >= 2 && nest.size() == permMap.size()) {
+ if (checkValidity.getValue() &&
+ !isValidLoopInterchangePermutation(nest, permMap))
+ continue;
permuteLoops(nest, permMap);
}
}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index 418c884..882d5ab 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -30,6 +30,7 @@ set(LIBS
MLIRVectorDialect
MLIRVectorToLLVMPass
MLIRXeVMDialect
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_library(MLIRGPUTestPasses
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index ab02866..69a3d98 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,7 +6,11 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Pass/Pass.h"
@@ -34,6 +38,10 @@ struct TestLLVMLegalizePatternsPass
: public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
+ TestLLVMLegalizePatternsPass() = default;
+ TestLLVMLegalizePatternsPass(const TestLLVMLegalizePatternsPass &other)
+ : PassWrapper(other) {}
+
StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
StringRef getDescription() const final {
return "Run LLVM dialect legalization patterns";
@@ -45,22 +53,48 @@ struct TestLLVMLegalizePatternsPass
void runOnOperation() override {
MLIRContext *ctx = &getContext();
+
+ // Set up type converter.
LLVMTypeConverter converter(ctx);
+ converter.addConversion(
+ [&](IntegerType type, SmallVectorImpl<Type> &result) {
+ if (type.isInteger(17)) {
+ // Convert i17 -> (i18, i18).
+ result.append(2, Builder(ctx).getIntegerType(18));
+ return success();
+ }
+
+ result.push_back(type);
+ return success();
+ });
+
+ // Populate patterns.
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
+ arith::populateArithToLLVMConversionPatterns(converter, patterns);
+ populateFuncToLLVMConversionPatterns(converter, patterns);
+ cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// Define the conversion target used for the test.
ConversionTarget target(*ctx);
target.addLegalOp(OperationName("test.legal_op", ctx));
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [&](func::FuncOp funcOp) { return funcOp->hasAttr("is_legal"); });
// Handle a partial conversion.
DenseSet<Operation *> unlegalizedOps;
ConversionConfig config;
config.unlegalizedOps = &unlegalizedOps;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config)))
getOperation()->emitError() << "applyPartialConversion failed";
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9..2cf25d8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
+ linalg::populateExtractSliceSinkingPatterns(
+ patterns, [](OpOperand *opOperand) { return true; });
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index 91e70d1..900dff3 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMathTestPasses
TestAlgebraicSimplification.cpp
- TestExpandMath.cpp
TestPolynomialApproximation.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
deleted file mode 100644
index efc1acf..0000000
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-//===- TestExpandMath.cpp - Test expand math op into exp form -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains test passes for expanding math operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestExpandMathPass
- : public PassWrapper<TestExpandMathPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
-
- void runOnOperation() override;
- StringRef getArgument() const final { return "test-expand-math"; }
- void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
- }
- StringRef getDescription() const final { return "Test expanding math"; }
-};
-} // namespace
-
-void TestExpandMathPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- populateExpandCtlzPattern(patterns);
- populateExpandExp2FPattern(patterns);
- populateExpandTanPattern(patterns);
- populateExpandSinhPattern(patterns);
- populateExpandCoshPattern(patterns);
- populateExpandTanhPattern(patterns);
- populateExpandAsinhPattern(patterns);
- populateExpandAcoshPattern(patterns);
- populateExpandAtanhPattern(patterns);
- populateExpandFmaFPattern(patterns);
- populateExpandCeilFPattern(patterns);
- populateExpandPowFPattern(patterns);
- populateExpandFPowIPattern(patterns);
- populateExpandRoundFPattern(patterns);
- populateExpandRoundEvenPattern(patterns);
- populateExpandRsqrtPattern(patterns);
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-}
-
-namespace mlir {
-namespace test {
-void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 5890913..fe1e916 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -385,13 +385,15 @@ TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const {
//===----------------------------------------------------------------------===//
bool TestConstMemorySpaceAttr::isValidLoad(
- Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, mlir::ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
return true;
}
bool TestConstMemorySpaceAttr::isValidStore(
- Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
+ Type type, mlir::ptr::AtomicOrdering ordering,
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
if (emitError)
emitError() << "memory space is read-only";
@@ -400,7 +402,8 @@ bool TestConstMemorySpaceAttr::isValidStore(
bool TestConstMemorySpaceAttr::isValidAtomicOp(
mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering,
- IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const {
+ std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
+ function_ref<InFlightDiagnostic()> emitError) const {
if (emitError)
emitError() << "memory space is read-only";
return false;
@@ -408,7 +411,8 @@ bool TestConstMemorySpaceAttr::isValidAtomicOp(
bool TestConstMemorySpaceAttr::isValidAtomicXchg(
Type type, mlir::ptr::AtomicOrdering successOrdering,
- mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
+ mlir::ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
+ const ::mlir::DataLayout *dataLayout,
function_ref<InFlightDiagnostic()> emitError) const {
if (emitError)
emitError() << "memory space is read-only";
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a4c615b..987e8f36 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -236,13 +236,14 @@ void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
// Dynamic operations
//===----------------------------------------------------------------------===//
-std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicGenericOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
"dynamic_generic", dialect, [](Operation *op) { return success(); },
[](Operation *op) { return success(); });
}
-std::unique_ptr<DynamicOpDefinition>
+static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
"dynamic_one_operand_two_results", dialect,
@@ -262,7 +263,7 @@ getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
[](Operation *op) { return success(); });
}
-std::unique_ptr<DynamicOpDefinition>
+static std::unique_ptr<DynamicOpDefinition>
getDynamicCustomParserPrinterOp(TestDialect *dialect) {
auto verifier = [](Operation *op) {
if (op->getNumOperands() == 0 && op->getNumResults() == 0)
diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
index 10e424a..51938d4 100644
--- a/mlir/test/lib/Dialect/Test/TestEnumDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
@@ -27,9 +27,10 @@ def SomeI32Enum : I32EnumAttr<"SomeI32Enum", "",
def I64Case5: I64EnumAttrCase<"case5", 5>;
def I64Case10: I64EnumAttrCase<"case10", 10>;
+def I64Case1p32 : I64EnumAttrCase<"caseLarse", 4294967296>;
def SomeI64Enum: I64EnumAttr<
- "SomeI64Enum", "", [I64Case5, I64Case10]>;
+ "SomeI64Enum", "", [I64Case5, I64Case10, I64Case1p32]>;
//===----------------------------------------------------------------------===//
// Test Enum
@@ -53,6 +54,13 @@ def TestSimpleEnum : I32Enum<"SimpleEnum", "", [
let cppNamespace = "::test";
}
+def TestSimpleEnum64 : I64Enum<"SimpleEnum64", "", [
+ I64EnumCase<"a", 4294967296>,
+ I64EnumCase<"b", 4294967297>
+ ]> {
+ let cppNamespace = "::test";
+}
+
//===----------------------------------------------------------------------===//
// Test Bit Enum
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 843bd30..231400e 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1169,6 +1169,26 @@ def OpP : TEST_Op<"op_p"> {
let results = (outs I32);
}
+// Test constant-folding a pattern that maps `(F32) -> SI32`.
+def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
+ let arguments = (ins RankedTensorOf<[F32]>:$operand);
+ let results = (outs RankedTensorOf<[SI32]>:$result);
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+// Test constant-folding a pattern that maps `(F32, F32) -> I1`.
+def LessThanOp : TEST_Op<"less_than", [SameOperandsAndResultShape]> {
+ let arguments = (ins RankedTensorOf<[F32]>:$lhs, RankedTensorOf<[F32]>:$rhs);
+ let results = (outs RankedTensorOf<[I1]>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` functional-type(operands, results)
+ }];
+}
+
// Test same operand name enforces equality condition check.
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
@@ -1478,6 +1498,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let results = (outs I32);
let hasFolder = 1;
}
+def : Pat<(TestOpInPlaceSelfFold:$op $_),
+ (TestOpConstant ConstantAttr<I32Attr, "97">)>;
// Test op that simply returns success.
def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 7150401..95f381e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -10,8 +10,10 @@
#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
@@ -202,6 +204,66 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
}
};
+struct FoldSignOpF32ToSI32 : public OpRewritePattern<test::SignOp> {
+ using OpRewritePattern<test::SignOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(test::SignOp op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumOperands() != 1 || op->getNumResults() != 1)
+ return failure();
+
+ TypedAttr operandAttr;
+ matchPattern(op->getOperand(0), m_Constant(&operandAttr));
+ if (!operandAttr)
+ return failure();
+
+ TypedAttr res = cast_or_null<TypedAttr>(
+ constFoldUnaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+ operandAttr, op.getType(), [](APFloat operand) -> APSInt {
+ static const APFloat zero(0.0f);
+ int operandSign = 0;
+ if (operand != zero)
+ operandSign = (operand < zero) ? -1 : +1;
+ return APSInt(APInt(32, operandSign), false);
+ }));
+ if (!res)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, res);
+ return success();
+ }
+};
+
+struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
+ using OpRewritePattern<test::LessThanOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(test::LessThanOp op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumOperands() != 2 || op->getNumResults() != 1)
+ return failure();
+
+ TypedAttr lhsAttr;
+ TypedAttr rhsAttr;
+ matchPattern(op->getOperand(0), m_Constant(&lhsAttr));
+ matchPattern(op->getOperand(1), m_Constant(&rhsAttr));
+
+ if (!lhsAttr || !rhsAttr)
+ return failure();
+
+ Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
+ TypedAttr res = cast_or_null<TypedAttr>(
+ constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+ operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
+ return APInt(1, lhs < rhs);
+ }));
+ if (!res)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, res);
+ return success();
+ }
+};
+
/// This pattern moves "test.move_before_parent_op" before the parent op.
struct MoveBeforeParentOp : public RewritePattern {
MoveBeforeParentOp(MLIRContext *context)
@@ -1116,8 +1178,8 @@ struct TestNonRootReplacement : public RewritePattern {
auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType);
auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType);
- rewriter.replaceOp(illegalOp, legalOp);
rewriter.replaceOp(op, illegalOp);
+ rewriter.replaceOp(illegalOp, legalOp);
return success();
}
};
@@ -1301,6 +1363,7 @@ public:
// Helper function that replaces the given op with a new op of the given
// name and doubles each result (1 -> 2 replacement of each result).
auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+ rewriter.setInsertionPointAfter(op);
SmallVector<Type> types;
for (Type t : op->getResultTypes()) {
types.push_back(t);
@@ -1324,6 +1387,23 @@ public:
}
};
+/// Pattern that erases 'test.type_consumers' iff the input operand is the
+/// result of a 1:1 type conversion.
+/// Used to test correct skipping of 1:1 patterns in the 1:N case.
+class TestTypeConsumerOpPattern
+ : public OpConversionPattern<TestTypeConsumerOp> {
+public:
+ TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter)
+ : OpConversionPattern<TestTypeConsumerOp>(converter, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1435,8 +1515,8 @@ struct TestLegalizePatternDriver
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
- TestBlockArgReplace, TestReplaceWithValidConsumer>(
- &getContext(), converter);
+ TestBlockArgReplace, TestReplaceWithValidConsumer,
+ TestTypeConsumerOpPattern>(&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1446,8 +1526,8 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp, OneRegionOp, TestValidProducerOp,
- TestValidConsumerOp>();
+ TerminatorOp, TestOpConstant, OneRegionOp,
+ TestValidProducerOp, TestValidConsumerOp>();
target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -1495,12 +1575,18 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
[](ConvertBlockArgsOp op) { return op.getIsLegal(); });
+ // Set up configuration.
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ config.foldingMode = foldingMode;
+ config.buildMaterializations = buildMaterializations;
+ config.attachDebugMaterializationKind = attachDebugMaterializationKind;
+ DumpNotifications dumpNotifications;
+ config.listener = &dumpNotifications;
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
- ConversionConfig config;
- DumpNotifications dumpNotifications;
- config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
@@ -1519,9 +1605,6 @@ struct TestLegalizePatternDriver
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
});
- ConversionConfig config;
- DumpNotifications dumpNotifications;
- config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
std::move(patterns), config))) {
getOperation()->emitRemark() << "applyFullConversion failed";
@@ -1534,7 +1617,6 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
- ConversionConfig config;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
@@ -1555,6 +1637,34 @@ struct TestLegalizePatternDriver
clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"),
clEnumValN(ConversionMode::Partial, "partial",
"Perform a partial conversion"))};
+
+ Option<DialectConversionFoldingMode> foldingMode{
+ *this, "test-legalize-folding-mode",
+ llvm::cl::desc("The folding mode to use with the test driver"),
+ llvm::cl::init(DialectConversionFoldingMode::BeforePatterns),
+ llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never",
+ "Never attempt to fold"),
+ clEnumValN(DialectConversionFoldingMode::BeforePatterns,
+ "before-patterns",
+ "Only attempt to fold not legal operations "
+ "before applying patterns"),
+ clEnumValN(DialectConversionFoldingMode::AfterPatterns,
+ "after-patterns",
+ "Only attempt to fold not legal operations "
+ "after applying patterns"))};
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
+ Option<bool> attachDebugMaterializationKind{
+ *this, "attach-debug-materialization-kind",
+ llvm::cl::desc(
+ "Attach materialization kind to unrealized_conversion_cast ops"),
+ llvm::cl::init(false)};
+ Option<bool> buildMaterializations{
+ *this, "build-materializations",
+ llvm::cl::desc(
+ "If set to 'false', leave unrealized_conversion_cast ops in place"),
+ llvm::cl::init(true)};
};
} // namespace
@@ -1874,9 +1984,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
: ConversionPattern(converter, "test.replace_with_legal_op",
/*benefit=*/1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+ rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
return success();
}
};
@@ -1885,6 +1995,10 @@ struct TestTypeConversionDriver
: public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
+ TestTypeConversionDriver() = default;
+ TestTypeConversionDriver(const TestTypeConversionDriver &other)
+ : PassWrapper(other) {}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TestDialect>();
}
@@ -1911,8 +2025,13 @@ struct TestTypeConversionDriver
// Otherwise, the type is illegal.
return nullptr;
});
- converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
- // Drop all integer types.
+ converter.addConversion([](IndexType type) { return type; });
+ converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+ if (type.isInteger(38)) {
+ // i38 is legal.
+ types.push_back(type);
+ }
+ // Drop all other integer types.
return success();
});
converter.addConversion(
@@ -1949,6 +2068,33 @@ struct TestTypeConversionDriver
results.push_back(result);
return success();
});
+ converter.addConversion([](Value v) -> std::optional<Type> {
+ // Context-aware type conversion rule that converts i37 to
+ // i(37 + increment). The increment is taken from the enclosing
+ // function.
+ auto intType = dyn_cast<IntegerType>(v.getType());
+ if (!intType || intType.getWidth() != 37)
+ return std::nullopt;
+ Region *r = v.getParentRegion();
+ if (!r) {
+ // No enclosing region found. This can happen when running with
+ // allow-pattern-rollback = true. Context-aware type conversions are
+ // not fully supported when running in rollback mode.
+ return Type();
+ }
+ Operation *op = r->getParentOp();
+ if (!op)
+ return Type();
+ if (!isa<FunctionOpInterface>(op))
+ op = op->getParentOfType<FunctionOpInterface>();
+ if (!op)
+ return Type();
+ auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+ if (!incrementAttr)
+ return Type();
+ return IntegerType::get(v.getContext(),
+ intType.getWidth() + incrementAttr.getInt());
+ });
/// Add the legal set of type materializations.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1969,9 +2115,19 @@ struct TestTypeConversionDriver
// Otherwise, fail.
return nullptr;
});
+ // Materialize i37 to any desired type with unrealized_conversion_cast.
+ converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1 || !inputs[0].getType().isInteger(37))
+ return Value();
+ return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+ .getResult(0);
+ });
// Initialize the conversion target.
mlir::ConversionTarget target(getContext());
+ target.addLegalOp(OperationName("test.context_op", &getContext()));
target.addLegalOp<LegalOpD>();
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
@@ -2002,11 +2158,19 @@ struct TestTypeConversionDriver
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
+ mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+ converter, patterns, target);
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ std::move(patterns), config)))
signalPassFailure();
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace
@@ -2226,6 +2390,24 @@ struct TestSelectiveReplacementPatternDriver
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
+
+struct TestFoldTypeConvertingOp
+ : public PassWrapper<TestFoldTypeConvertingOp, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldTypeConvertingOp)
+
+ StringRef getArgument() const final { return "test-fold-type-converting-op"; }
+ StringRef getDescription() const final {
+ return "Test helper functions for folding ops whose input and output types "
+ "differ, e.g. float comparisons of the form `(f32, f32) -> i1`.";
+ }
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ patterns.add<FoldSignOpF32ToSI32, FoldLessThanOpF32ToI1>(context);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -2256,6 +2438,8 @@ void registerPatternsTestPass() {
PassRegistration<TestMergeBlocksPatternDriver>();
PassRegistration<TestSelectiveReplacementPatternDriver>();
+
+ PassRegistration<TestFoldTypeConvertingOp>();
}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 6457487..5f93035 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -178,6 +178,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
newTosaConv2DOp.getResult().getType().isUnsignedInteger();
bool outputUnsigned = outputType.isUnsignedInteger();
+ RoundingModeAttr doubleRoundAttr =
+ RoundingModeAttr::get(rewriter.getContext(), RoundingMode::DOUBLE_ROUND);
auto newTosaRescaleOp = tosa::RescaleOp::create(
rewriter, op->getLoc(), outputType, newTosaConv2DOp.getResult(),
getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
@@ -185,7 +187,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
{static_cast<int8_t>(shift)}),
inputZp.value(), outputZp.value(),
/* scale32 = */ rewriter.getBoolAttr(true),
- /* double_round = */ rewriter.getStringAttr("DOUBLE_ROUND"),
+ /* double_round = */ doubleRoundAttr,
/* per_channel = */ rewriter.getBoolAttr(false),
rewriter.getBoolAttr(inputUnsigned),
rewriter.getBoolAttr(outputUnsigned));
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f89c944..bb1598e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -786,6 +786,28 @@ struct TestVectorGatherLowering
}
};
+struct TestUnrollVectorFromElements
+ : public PassWrapper<TestUnrollVectorFromElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-from-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for from_elements ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorFromElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
+ PassRegistration<TestUnrollVectorFromElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c6245b6..200323c 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -7,11 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -79,7 +82,7 @@ struct TestXeGPUUnrollingPatterns
if (auto layout = tdescTy.getLayoutAttr()) {
auto inst_data = layout.getInstData();
- if (inst_data && layout.isSgLayout())
+ if (inst_data && layout.isForSubgroup())
return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
inst_data.asArrayRef().end());
}
@@ -147,12 +150,118 @@ struct TestXeGPUUnrollingPatterns
}
};
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "test-xegpu-layout-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+// Test pattern for distributing vector::StepOp from workgroup to subgroup.
+// Validates DistributeLayoutAttr interfaces for offset computation
+// abstraction between LayoutAttr and SliceAttr.
+class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto layoutName = xegpu::getLayoutName(op->getResult(0));
+ auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+ if (!sliceAttr || sliceAttr.getRank() != 1)
+ return failure();
+
+ std::optional<SmallVector<int64_t>> sgShape = sliceAttr.getSgDataAsInt();
+ if (!sgShape)
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(maybeOffsets))
+ return failure();
+
+ VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+ Value base = vector::StepOp::create(rewriter, loc, newTy);
+ SmallVector<Value> newOps;
+ for (auto offsets : *maybeOffsets) {
+ Value bcast =
+ vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+ Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+ newOps.push_back(add);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+};
+
+struct TestXeGPULayoutInterface
+ : public PassWrapper<TestXeGPULayoutInterface,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPULayoutInterface)
+
+ StringRef getArgument() const final { return "test-xegpu-layout-interface"; }
+
+ StringRef getDescription() const final {
+ return "Test the implementation of XeGPU Layout interfaces";
+ }
+
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect>();
+ registry.insert<memref::MemRefDialect>();
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<vector::VectorDialect>();
+ registry.insert<index::IndexDialect>();
+ }
+
+ TestXeGPULayoutInterface() = default;
+ TestXeGPULayoutInterface(const TestXeGPULayoutInterface &pass)
+ : PassWrapper(pass) {}
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+
+ TypeConverter typeConverter;
+ auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
+ mlir::ValueRange inputs,
+ mlir::Location loc) -> mlir::Value {
+ return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
+ .getResult(0);
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ RewritePatternSet patterns(ctx);
+ patterns.add<TestStepOpPattern>(typeConverter, ctx);
+
+ ConversionTarget target(*ctx);
+ auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
+ return !layout || !layout.isForWorkgroup();
+ };
+
+ target.addDynamicallyLegalOp<vector::StepOp>(
+ [&](vector::StepOp op) -> bool {
+ auto layoutName = xegpu::getLayoutName(op->getResult(0));
+ auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+ return isLegal(sliceAttr);
+ });
+
+ target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
namespace test {
void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
+ PassRegistration<TestXeGPULayoutInterface>();
}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 25c8e53..df2736b 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -133,6 +133,51 @@ struct TestOptionsSuperPass
llvm::cl::desc("Example list of PassPipelineOptions option")};
};
+struct TestOptionsPassA
+ : public PassWrapper<TestOptionsPassA, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPassA)
+
+ struct Options : public PassPipelineOptions<Options> {
+ Option<bool> foo{*this, "foo", llvm::cl::desc("Example boolean option")};
+ };
+
+ TestOptionsPassA() = default;
+ TestOptionsPassA(const TestOptionsPassA &) : PassWrapper() {}
+ TestOptionsPassA(const Options &options) { this->options.foo = options.foo; }
+
+ void runOnOperation() final {}
+ StringRef getArgument() const final { return "test-options-pass-a"; }
+ StringRef getDescription() const final {
+ return "Test superset options parsing capabilities - subset A";
+ }
+
+ Options options;
+};
+
+struct TestOptionsPassB
+ : public PassWrapper<TestOptionsPassB, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPassB)
+
+ struct Options : public PassPipelineOptions<Options> {
+ Option<bool> bar{*this, "bar", llvm::cl::desc("Example boolean option")};
+ };
+
+ TestOptionsPassB() = default;
+ TestOptionsPassB(const TestOptionsPassB &) : PassWrapper() {}
+ TestOptionsPassB(const Options &options) { this->options.bar = options.bar; }
+
+ void runOnOperation() final {}
+ StringRef getArgument() const final { return "test-options-pass-b"; }
+ StringRef getDescription() const final {
+ return "Test superset options parsing capabilities - subset B";
+ }
+
+ Options options;
+};
+
+struct TestPipelineOptionsSuperSetAB : TestOptionsPassA::Options,
+ TestOptionsPassB::Options {};
+
/// A test pass that always aborts to enable testing the crash recovery
/// mechanism of the pass manager.
struct TestCrashRecoveryPass
@@ -270,6 +315,9 @@ void registerPassManagerTestPass() {
PassRegistration<TestOptionsPass>();
PassRegistration<TestOptionsSuperPass>();
+ PassRegistration<TestOptionsPassA>();
+ PassRegistration<TestOptionsPassB>();
+
PassRegistration<TestModulePass>();
PassRegistration<TestFunctionPass>();
@@ -306,5 +354,16 @@ void registerPassManagerTestPass() {
[](OpPassManager &pm, const TestOptionsSuperPass::Options &options) {
pm.addPass(std::make_unique<TestOptionsSuperPass>(options));
});
+
+ PassPipelineRegistration<TestPipelineOptionsSuperSetAB>
+ registerPipelineOptionsSuperSetABPipeline(
+ "test-options-super-set-ab-pipeline",
+ "Parses options of PassPipelineOptions using pass pipeline "
+ "registration",
+ [](OpPassManager &pm, const TestPipelineOptionsSuperSetAB &options) {
+ // Pass superset AB options to subset options A and B
+ pm.addPass(std::make_unique<TestOptionsPassA>(options));
+ pm.addPass(std::make_unique<TestOptionsPassB>(options));
+ });
}
} // namespace mlir
diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
index c3c871a..f5a6fc5 100644
--- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
+++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
@@ -29,7 +29,7 @@ using namespace mlir;
namespace mlir::test {
std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules,
bool nestInGPUModule);
-}
+} // namespace mlir::test
namespace {
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index f392bda..f99c24d 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -29,6 +29,7 @@ lit_shell_env = os.environ.get("LIT_USE_INTERNAL_SHELL")
if lit_shell_env:
use_lit_shell = lit.util.pythonize_bool(lit_shell_env)
+# Set the test format based on shell configuration
config.test_format = lit.formats.ShTest(execute_external=not use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
@@ -181,15 +182,16 @@ config.excludes = [
]
# Tweak the PATH to include the tools dir.
-llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True)
-llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
+tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir, config.lit_tools_dir]
+for dirs in tool_dirs:
+ llvm_config.with_environment("PATH", dirs, append_path=True)
-tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
"mlir-tblgen",
"mlir-translate",
"mlir-lsp-server",
"mlir-capi-execution-engine-test",
+ "mlir-capi-global-constructors-test",
"mlir-capi-ir-test",
"mlir-capi-irdl-test",
"mlir-capi-llvm-test",
@@ -378,3 +380,6 @@ if config.run_rocm_tests:
if config.arm_emulator_executable:
config.available_features.add("arm-emulator")
+
+if sys.version_info >= (3, 11):
+ config.available_features.add("python-ge-311")
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index d904780..8a742a2 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -5,6 +5,7 @@ import sys
config.target_triple = "@LLVM_TARGET_TRIPLE@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@")
+config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.spirv_tools_tests = @LLVM_INCLUDE_SPIRV_TOOLS_TESTS@
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@"))
@@ -33,6 +34,7 @@ config.run_rocm_tests = @MLIR_ENABLE_ROCM_CONVERSIONS@
config.enable_rocm_runner = @MLIR_ENABLE_ROCM_RUNNER@
config.gpu_compilation_format = "@MLIR_GPU_COMPILATION_TEST_FORMAT@"
config.rocm_test_chipset = "@ROCM_TEST_CHIPSET@"
+config.run_xevm_tests = @MLIR_ENABLE_XEVM_CONVERSIONS@
config.enable_sycl_runner = @MLIR_ENABLE_SYCL_RUNNER@
config.enable_levelzero_runner = @MLIR_ENABLE_LEVELZERO_RUNNER@
config.enable_spirv_cpu_runner = @MLIR_ENABLE_SPIRV_CPU_RUNNER@
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index b599c9d..3f9d3f2 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(math-expand-ops),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \
// RUN: | mlir-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%mlir_c_runner_utils \
diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index 1c5567f..cd23b6a 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -62,12 +62,15 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
// CHECK: def _myenum64(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
+def User : I32BitEnumAttrCaseBit<"User", 0, "user">;
+def Group : I32BitEnumAttrCaseBit<"Group", 1, "group">;
+def Other : I32BitEnumAttrCaseBit<"Other", 2, "other">;
+
def TestBitEnum
- : I32BitEnumAttr<"TestBitEnum", "", [
- I32BitEnumAttrCaseBit<"User", 0, "user">,
- I32BitEnumAttrCaseBit<"Group", 1, "group">,
- I32BitEnumAttrCaseBit<"Other", 2, "other">,
- ]> {
+ : I32BitEnumAttr<
+ "TestBitEnum", "",
+ [User, Group, Other,
+ I32BitEnumAttrCaseGroup<"Any", [User, Group, Other], "any">]> {
let genSpecializedAttr = 0;
let separator = " | ";
}
@@ -79,9 +82,10 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: User = 1
// CHECK: Group = 2
// CHECK: Other = 4
+// CHECK: Any = 7
// CHECK: def __iter__(self):
-// CHECK: return iter([case for case in type(self) if (self & case) is case])
+// CHECK: return iter([case for case in type(self) if (self & case) is case and self is not case])
// CHECK: def __len__(self):
// CHECK: return bin(self).count("1")
@@ -94,6 +98,8 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: return "group"
// CHECK: if self is TestBitEnum.Other:
// CHECK: return "other"
+// CHECK: if self is TestBitEnum.Any:
+// CHECK: return "any"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
// CHECK: @register_attribute_builder("TestBitEnum")
diff --git a/mlir/test/mlir-translate/emitc_classops.mlir b/mlir/test/mlir-translate/emitc_classops.mlir
deleted file mode 100644
index d880f9b..0000000
--- a/mlir/test/mlir-translate/emitc_classops.mlir
+++ /dev/null
@@ -1,78 +0,0 @@
-// RUN: mlir-translate --mlir-to-cpp %s | FileCheck %s
-
-emitc.class @modelClass {
- emitc.field @fieldName0 : !emitc.array<1xf32>
- emitc.field @fieldName1 : !emitc.array<1xf32>
- emitc.func @execute() {
- %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
- %1 = get_field @fieldName0 : !emitc.array<1xf32>
- %2 = get_field @fieldName1 : !emitc.array<1xf32>
- %3 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- return
- }
-}
-
-// CHECK-LABEL: class modelClass {
-// CHECK-NEXT: public:
-// CHECK-NEXT: float fieldName0[1];
-// CHECK-NEXT: float fieldName1[1];
-// CHECK-NEXT: void execute() {
-// CHECK-NEXT: size_t v1 = 0;
-// CHECK-NEXT: return;
-// CHECK-NEXT: }
-// CHECK-NEXT: };
-
-emitc.class final @finalClass {
- emitc.field @fieldName0 : !emitc.array<1xf32>
- emitc.field @fieldName1 : !emitc.array<1xf32>
- emitc.func @execute() {
- %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
- %1 = get_field @fieldName0 : !emitc.array<1xf32>
- %2 = get_field @fieldName1 : !emitc.array<1xf32>
- %3 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- return
- }
-}
-
-// CHECK-LABEL: class finalClass final {
-// CHECK-NEXT: public:
-// CHECK-NEXT: float fieldName0[1];
-// CHECK-NEXT: float fieldName1[1];
-// CHECK-NEXT: void execute() {
-// CHECK-NEXT: size_t v1 = 0;
-// CHECK-NEXT: return;
-// CHECK-NEXT: }
-// CHECK-NEXT: };
-
-emitc.class @mainClass {
- emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
- emitc.func @get_fieldName0() {
- %0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
- return
- }
-}
-
-// CHECK-LABEL: class mainClass {
-// CHECK-NEXT: public:
-// CHECK-NEXT: float fieldName0[2] = {0.0e+00f, 0.0e+00f};
-// CHECK-NEXT: void get_fieldName0() {
-// CHECK-NEXT: return;
-// CHECK-NEXT: }
-// CHECK-NEXT: };
-
-emitc.class @reflectionClass {
- emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> = #emitc.opaque<"{ { \22another_feature\22, \22fieldName0\22 } }">
- emitc.func @get_reflectionMap() {
- %0 = emitc.get_field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>">
- return
- }
-}
-
-// CHECK-LABEL: class reflectionClass {
-// CHECK-NEXT: public:
-// CHECK-NEXT: const std::map<std::string, std::string> reflectionMap = { { "another_feature", "fieldName0" } };
-// CHECK-NEXT: void get_reflectionMap() {
-// CHECK-NEXT: return;
-// CHECK-NEXT: }
-// CHECK-NEXT: };
-
diff --git a/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py b/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
index ee76b6d..bc273bf 100644
--- a/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
+++ b/mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py
@@ -1,7 +1,7 @@
# RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops | FileCheck %s
# Just verify that at least one known op is generated.
-# CHECK: name: matmul
+# CHECK: name: copy
# verify some special cases: negf->NegFOp, powf->PowFOp
# CHECK cpp_class_name: NegFOp
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 0eef97d..3eb62be 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -5,6 +5,8 @@ from mlir.ir import *
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import func
+import mlir.extras.types as T
+from mlir.dialects import arith
def constructAndPrintInModule(f):
@@ -25,6 +27,7 @@ def testSmoke():
"!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
)
shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
+
# CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
@func.FuncOp.from_py_func(i64, i64)
def wgmma_f32_f16_f16(desc_a, desc_b):
@@ -48,3 +51,41 @@ def testSmoke():
layoutA=nvvm.MMALayout.col,
layoutB=nvvm.MMALayout.col,
)
+
+
+# CHECK-LABEL: TEST: test_inline_ptx
+# CHECK-LABEL: func.func @my_inline_ptx(
+# CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: f32, %[[arg1:[a-zA-Z0-9_]+]]: f32, %[[arg2:[a-zA-Z0-9_]+]]: i32, %[[arg3:[a-zA-Z0-9_]+]]: i32)
+# CHECK: %[[S0:.+]]:2 = nvvm.inline_ptx
+# CHECK-SAME: ro(%[[arg0]], %[[arg1]] : f32, f32) rw(%[[arg2]], %[[arg3]] : i32, i32) -> f32, f32
+# CHECK: %[[S1:.+]] = arith.addf %[[arg0]], %[[arg1]] : f32
+# CHECK: %[[S2:.+]] = arith.addi %[[arg2]], %[[arg3]] : i32
+# CHECK: %[[S3:.+]] = arith.addf %[[S0]]#0, %[[S0]]#1 : f32
+
+
+@constructAndPrintInModule
+def test_inline_ptx():
+ i32 = T.i32()
+ f32 = T.f32()
+
+ @func.FuncOp.from_py_func(f32, f32, i32, i32)
+ def my_inline_ptx(a, b, c, d):
+ ptx = r"""
+ {
+ .reg .pred p;
+ setp.ge.s32 p, {$r0}, {$r1};
+ selp.s32 {$r0}, {$r0}, {$r1}, p;
+ selp.s32 {$r1}, {$r0}, {$r1}, p;
+ selp.s32 {$rw0}, {$r0}, {$r1}, p;
+ selp.s32 {$rw1}, {$r0}, {$r1}, p;
+ }
+ """
+ wo0, wo1 = nvvm.inline_ptx(
+ read_only_args=[a, b],
+ read_write_args=[c, d],
+ write_only_args=[f32, f32],
+ ptx_code=ptx,
+ )
+ arith.addf(a, b)
+ arith.addi(c, d)
+ arith.addf(wo0, wo1)
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a51f215..5a648fe 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -46,6 +46,8 @@ def non_configurable_patterns():
vector.ApplyLowerOuterProductPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_gather
vector.ApplyLowerGatherPatternsOp()
+ # CHECK: transform.apply_patterns.vector.unroll_from_elements
+ vector.ApplyUnrollFromElementsPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_scan
vector.ApplyLowerScanPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_shape_cast
diff --git a/mlir/test/python/global_constructors.py b/mlir/test/python/global_constructors.py
new file mode 100644
index 0000000..5020c00
--- /dev/null
+++ b/mlir/test/python/global_constructors.py
@@ -0,0 +1,72 @@
+# UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}}
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+# REQUIRES: host-supports-jit
+import gc, sys, os, tempfile
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.execution_engine import *
+from mlir.runtime import *
+
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors/info emitted by MLIR to stderr.
+def log(*args):
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
+
+def run(f):
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+
+def lowerToLLVM(module):
+ pm = PassManager.parse(
+ "builtin.module(convert-func-to-llvm,reconcile-unrealized-casts)"
+ )
+ pm.run(module.operation)
+ return module
+
+
+# Test JIT callback in global constructor
+# CHECK-LABEL: TEST: testJITCallbackInGlobalCtor
+def testJITCallbackInGlobalCtor():
+ init_cnt = 0
+
+ @ctypes.CFUNCTYPE(None)
+ def initCallback():
+ nonlocal init_cnt
+ init_cnt += 1
+
+ with Context():
+ module = Module.parse(
+ r"""
+llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
+llvm.func @ctor() {
+ func.call @init_callback() : () -> ()
+ llvm.return
+}
+func.func private @init_callback() attributes { llvm.emit_c_interface }
+ """
+ )
+
+ # Setup execution engine
+ execution_engine = ExecutionEngine(lowerToLLVM(module))
+
+ # Validate initialization hasn't run yet
+ assert init_cnt == 0
+
+ # # Register callback
+ execution_engine.register_runtime("init_callback", initCallback)
+
+ # # Initialize and verify
+ execution_engine.initialize()
+ assert init_cnt == 1
+ # # Second initialization should be no-op
+ execution_engine.initialize()
+ assert init_cnt == 1
+
+
+run(testJITCallbackInGlobalCtor)
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
new file mode 100644
index 0000000..c2d5108
--- /dev/null
+++ b/mlir/test/python/ir/auto_location.py
@@ -0,0 +1,101 @@
+# RUN: %PYTHON %s | FileCheck %s
+# REQUIRES: python-ge-311
+import gc
+from contextlib import contextmanager
+
+from mlir.ir import *
+from mlir.dialects._ods_common import _cext
+from mlir.dialects import arith, _arith_ops_gen
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+
+@contextmanager
+def with_infer_location():
+ _cext.globals.set_loc_tracebacks_enabled(True)
+ yield
+ _cext.globals.set_loc_tracebacks_enabled(False)
+
+
+# CHECK-LABEL: TEST: testInferLocations
+@run
+def testInferLocations():
+ with Context() as ctx, with_infer_location():
+ ctx.allow_unregistered_dialects = True
+
+ op = Operation.create("custom.op1")
+ one = arith.constant(IndexType.get(), 1)
+ _cext.globals.register_traceback_file_exclusion(arith.__file__)
+ two = arith.constant(IndexType.get(), 2)
+
+ # fmt: off
+ # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":31:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))
+ # fmt: on
+ print(op.location)
+
+ # fmt: off
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+ # fmt: on
+ print(one.location)
+
+ # fmt: off
+ # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":34:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))
+ # fmt: on
+ print(two.location)
+
+ _cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
+ three = arith.constant(IndexType.get(), 3)
+ # fmt: off
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
+ # fmt: on
+ print(three.location)
+
+ def foo():
+ four = arith.constant(IndexType.get(), 4)
+ print(four.location)
+
+ # fmt: off
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+ # fmt: on
+ foo()
+
+ _cext.globals.register_traceback_file_exclusion(__file__)
+
+ # fmt: off
+ # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218))
+ # fmt: on
+ foo()
+
+ def bar1():
+ def bar2():
+ def bar3():
+ five = arith.constant(IndexType.get(), 5)
+ print(five.location)
+
+ bar3()
+
+ bar2()
+
+ _cext.globals.register_traceback_file_inclusion(__file__)
+ _cext.globals.register_traceback_file_exclusion(_arith_ops_gen.__file__)
+
+ _cext.globals.set_loc_tracebacks_frame_limit(2)
+ # fmt: off
+ # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":80:16 to :22)))
+ # fmt: on
+ bar1()
+
+ _cext.globals.set_loc_tracebacks_frame_limit(1)
+ # fmt: off
+ # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61))
+ # fmt: on
+ bar1()
+
+ _cext.globals.set_loc_tracebacks_frame_limit(0)
+ # CHECK: loc(unknown)
+ bar1()
diff --git a/mlir/test/python/ir/context_managers.py b/mlir/test/python/ir/context_managers.py
index 8091687..5d9f9ce 100644
--- a/mlir/test/python/ir/context_managers.py
+++ b/mlir/test/python/ir/context_managers.py
@@ -35,25 +35,14 @@ def testLocationEnterExit():
# Asserting a different context should clear it.
with Context() as ctx2:
assert Context.current is ctx2
- try:
- _ = Location.current
- except ValueError:
- pass
- else:
- assert False, "Expected exception"
+ assert Location.current is None
# And should restore.
assert Context.current is ctx1
assert Location.current is loc1
# All should clear.
- try:
- _ = Location.current
- except ValueError as e:
- # CHECK: No current Location
- print(e)
- else:
- assert False, "Expected exception"
+ assert Location.current is None
run(testLocationEnterExit)
@@ -72,12 +61,7 @@ def testInsertionPointEnterExit():
assert InsertionPoint.current is ip
assert Location.current is loc1
# Location should clear.
- try:
- _ = Location.current
- except ValueError:
- pass
- else:
- assert False, "Expected exception"
+ assert Location.current is None
# Asserting the same Context should preserve.
with ctx1:
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 6065e59..ad4c934 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,27 +121,17 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- live_ops = ctx._get_live_operation_objects()
- assert len(live_ops) == 1
- assert live_ops[0] is op1
- live_ops = None
# CHECK: module @successfulParse
print(op1)
# Ensure that operations are the same on multiple calls.
op2 = module.operation
- assert ctx._get_live_operation_count() == 1
- assert op1 is op2
+ assert op1 is not op2
+ assert op1 == op2
# Test live operation clearing.
op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- num_invalidated = ctx._clear_live_operations()
- assert num_invalidated == 1
- assert ctx._get_live_operation_count() == 0
op1 = None
gc.collect()
op1 = module.operation
@@ -155,9 +145,6 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
- print("LIVE OPERATIONS:", ctx._get_live_operation_count())
- assert ctx._get_live_operation_count() == 0
- assert ctx._get_live_module_count() == 0
# CHECK-LABEL: TEST: testModuleCapsule
@@ -165,16 +152,17 @@ def testModuleOperation():
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
- assert module is module_dup
+ assert module is not module_dup
+ assert module == module_dup
+ module._clear_mlir_module()
+ assert module != module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
- assert ctx._get_live_module_count() == 0
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c6b5daf..7759b17 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -907,7 +907,13 @@ def testCapsuleConversions():
m_capsule = m._CAPIPtr
assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
m2 = Operation._CAPICreate(m_capsule)
- assert m2 is m
+ assert m2 is not m
+ assert m2 == m
+ # Gc and verify destructed.
+ m = None
+ m_capsule = None
+ m2 = None
+ gc.collect()
# CHECK-LABEL: TEST: testOperationErase
@@ -1021,6 +1027,8 @@ def testDetachFromParent():
with Context():
m1 = Module.parse("func.func private @foo()")
func = m1.body.operations[0].detach_from_parent()
+ # CHECK: func.attached=False
+ print(f"{func.attached=}")
try:
func.detach_from_parent()
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 8b6d7ea..99d5fadf 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -56,6 +56,7 @@ def testSymbolTableInsert():
print(m1)
assert "bar" not in symbol_table
+ bar._set_invalid()
try:
print(bar)
except RuntimeError as e:
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e26d42b..5f92f5b 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,14 +176,6 @@ def testRunPipelineError():
@run
def testPostPassOpInvalidation():
with Context() as ctx:
- log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
-
- # CHECK: invalidate_ops=False
- log("invalidate_ops=False")
-
- # CHECK: live ops: 0
- log_op_count()
-
module = ModuleOp.parse(
"""
module {
@@ -196,9 +188,6 @@ def testPostPassOpInvalidation():
"""
)
- # CHECK: live ops: 1
- log_op_count()
-
outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
log(outer_const_op)
@@ -214,12 +203,7 @@ def testPostPassOpInvalidation():
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)
- # CHECK: live ops: 4
- log_op_count()
-
- PassManager.parse("builtin.module(canonicalize)").run(
- module, invalidate_ops=False
- )
+ PassManager.parse("builtin.module(canonicalize)").run(module)
# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
@@ -233,9 +217,6 @@ def testPostPassOpInvalidation():
# CHECK: invalidate_ops=True
log("invalidate_ops=True")
- # CHECK: live ops: 4
- log_op_count()
-
module = ModuleOp.parse(
"""
module {
@@ -247,30 +228,24 @@ def testPostPassOpInvalidation():
}
"""
)
- outer_const_op = module.body.operations[0]
- func_op = module.body.operations[1]
- inner_const_op = func_op.body.blocks[0].operations[0]
-
- # CHECK: live ops: 4
- log_op_count()
PassManager.parse("builtin.module(canonicalize)").run(module)
- # CHECK: live ops: 1
- log_op_count()
-
+ func_op._set_invalid()
try:
log(func_op)
except RuntimeError as e:
# CHECK: the operation has been invalidated
log(e)
+ outer_const_op._set_invalid()
try:
log(outer_const_op)
except RuntimeError as e:
# CHECK: the operation has been invalidated
log(e)
+ inner_const_op._set_invalid()
try:
log(inner_const_op)
except RuntimeError as e:
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 0a1693c..b8e28c6 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -449,7 +449,7 @@ static bool isAttribute(LinalgOperandDefKind kind) {
}
// Get the enum name for the given operand kind.
-std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
+static std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
switch (kind) {
case LinalgOperandDefKind::UnaryFnAttr:
return std::string("UnaryFn");
@@ -466,7 +466,7 @@ std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
}
// Get the enum name for the given function kind.
-std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
+static std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
switch (kind) {
case ScalarFnKind::Unary:
return std::string("UnaryFn");
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 14714c45..7b992b4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -98,7 +98,6 @@ void registerTestDiagnosticsMetadataPass();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestEmulateNarrowTypePass();
-void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
@@ -245,7 +244,6 @@ void registerTestPasses() {
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestEmulateNarrowTypePass();
- mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
diff --git a/mlir/tools/mlir-runner/mlir-runner.cpp b/mlir/tools/mlir-runner/mlir-runner.cpp
index 932c9f6..44ad660 100644
--- a/mlir/tools/mlir-runner/mlir-runner.cpp
+++ b/mlir/tools/mlir-runner/mlir-runner.cpp
@@ -32,7 +32,7 @@ using namespace mlir;
// TODO: Consider removing this linking functionality from the SPIR-V CPU Runner
// flow in favour of a more proper host/device split like other runners.
// https://github.com/llvm/llvm-project/issues/115348
-static llvm::cl::opt<bool> LinkNestedModules(
+static llvm::cl::opt<bool> linkNestedModules(
"link-nested-modules",
llvm::cl::desc("Link two nested MLIR modules into a single LLVM IR module. "
"Useful if both the host and device code can be run on the "
@@ -56,7 +56,7 @@ convertMLIRModule(Operation *op, llvm::LLVMContext &context) {
return op->emitError("op must be a 'builtin.module"), nullptr;
std::unique_ptr<llvm::Module> kernelModule;
- if (LinkNestedModules) {
+ if (linkNestedModules) {
// Verify that there is only one nested module.
auto modules = module.getOps<ModuleOp>();
if (!llvm::hasSingleElement(modules)) {
@@ -73,7 +73,7 @@ convertMLIRModule(Operation *op, llvm::LLVMContext &context) {
std::unique_ptr<llvm::Module> mainModule =
translateModuleToLLVMIR(module, context);
- if (LinkNestedModules)
+ if (linkNestedModules)
llvm::Linker::linkModules(*mainModule, std::move(kernelModule));
return mainModule;
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 10a162f81..a1899a8 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -585,7 +585,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
if (isa<ParameterElement>(arg))
genVariableParser(param, ctx, os.indent());
- else if (auto custom = dyn_cast<CustomDirective>(arg))
+ else if (auto *custom = dyn_cast<CustomDirective>(arg))
genCustomParser(custom, ctx, os.indent());
os.unindent() << "} else ";
// Print the check for duplicate or unknown parameter.
@@ -877,9 +877,9 @@ void DefFormat::genCommaSeparatedPrinter(
extra(arg);
shouldEmitSpace = false;
lastWasPunctuation = true;
- if (auto realParam = dyn_cast<ParameterElement>(arg))
+ if (auto *realParam = dyn_cast<ParameterElement>(arg))
genVariablePrinter(realParam, ctx, os);
- else if (auto custom = dyn_cast<CustomDirective>(arg))
+ else if (auto *custom = dyn_cast<CustomDirective>(arg))
genCustomPrinter(custom, ctx, os);
if (param->isOptional())
os.unindent() << "}\n";
@@ -1124,7 +1124,7 @@ DefFormatParser::verifyStructArguments(SMLoc loc,
return emitError(loc, "expected a parameter, custom directive or params "
"directive in `struct` arguments list");
}
- if (auto custom = dyn_cast<CustomDirective>(el)) {
+ if (auto *custom = dyn_cast<CustomDirective>(el)) {
if (custom->getNumElements() != 1) {
return emitError(loc, "`struct` can only contain `custom` directives "
"with a single argument");
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index da28ca3..533a9cf 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -151,9 +151,9 @@ void Generator::emitParse(StringRef kind, const Record &x) {
os << "\n\n";
}
-void printParseConditional(mlir::raw_indented_ostream &ios,
- ArrayRef<const Init *> args,
- ArrayRef<std::string> argNames) {
+static void printParseConditional(mlir::raw_indented_ostream &ios,
+ ArrayRef<const Init *> args,
+ ArrayRef<std::string> argNames) {
ios << "if ";
auto parenScope = ios.scope("(", ") {");
ios.indent();
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 8e2d611..acc9b61d 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -64,7 +64,7 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
if (enumInfo.isBitEnum()) {
os << formatv(" def __iter__(self):\n"
" return iter([case for case in type(self) if "
- "(self & case) is case])\n");
+ "(self & case) is case and self is not case])\n");
os << formatv(" def __len__(self):\n"
" return bin(self).count(\"1\")\n");
os << "\n";
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 06dc588..d152763 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -222,7 +222,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
llvm::StringSwitch<StringRef>(separator.trim())
.Case("|", "parseOptionalVerticalBar")
.Case(",", "parseOptionalComma")
- .Default("error, enum seperator must be '|' or ','");
+ .Default("error, enum separator must be '|' or ','");
os << formatv(parsedAndPrinterStartUnquotedBitEnum, qualName, cppNamespace,
enumInfo.getSummary(), casesList, separator,
parseSeparatorFn);
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index d2e38e9..038f56d 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -41,6 +41,7 @@ from ._ods_common import (
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
+_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Sequence as _Sequence, Union as _Union
diff --git a/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
index 6c4b619..9c5cc6a 100644
--- a/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
+++ b/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
@@ -18,10 +18,11 @@ using namespace llvm;
using namespace mlir;
// Generator that prints records.
-GenRegistration printRecords("print-records", "Print all records to stdout",
- [](const RecordKeeper &records, raw_ostream &os) {
- os << records;
- return false;
- });
+static GenRegistration
+ printRecords("print-records", "Print all records to stdout",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ os << records;
+ return false;
+ });
int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 4343f2d..48634a1 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -40,7 +40,7 @@ static llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
-Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
+static Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
MLIRContext *ctx = builder.getContext();
if (pred.isCombined()) {
@@ -68,21 +68,22 @@ Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
return op;
}
-Value typeToConstraint(OpBuilder &builder, Type type) {
+static Value typeToConstraint(OpBuilder &builder, Type type) {
MLIRContext *ctx = builder.getContext();
auto op =
irdl::IsOp::create(builder, UnknownLoc::get(ctx), TypeAttr::get(type));
return op.getOutput();
}
-Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+static Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
MLIRContext *ctx = builder.getContext();
auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx),
StringAttr::get(ctx, baseClass));
return op.getOutput();
}
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+static std::optional<Type> recordToType(MLIRContext *ctx,
+ const Record &predRec) {
if (predRec.isSubClassOf("I")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -171,7 +172,8 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
return std::nullopt;
}
-Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+static Value createTypeConstraint(OpBuilder &builder,
+ tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
@@ -260,7 +262,8 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return createPredicate(builder, constraint.getPredicate());
}
-Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+static Value createAttrConstraint(OpBuilder &builder,
+ tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
@@ -341,7 +344,8 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return createPredicate(builder, constraint.getPredicate());
}
-Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) {
+static Value createRegionConstraint(OpBuilder &builder,
+ tblgen::Region constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
@@ -383,8 +387,8 @@ static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
}
/// Extract an operation to IRDL.
-irdl::OperationOp createIRDLOperation(OpBuilder &builder,
- tblgen::Operator &tblgenOp) {
+static irdl::OperationOp createIRDLOperation(OpBuilder &builder,
+ tblgen::Operator &tblgenOp) {
MLIRContext *ctx = builder.getContext();
StringRef opName = getOperatorName(tblgenOp);
@@ -488,7 +492,8 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
return op;
}
-irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
+static irdl::TypeOp createIRDLType(OpBuilder &builder,
+ tblgen::TypeDef &tblgenType) {
MLIRContext *ctx = builder.getContext();
StringRef typeName = getTypeName(tblgenType);
std::string combined = ("!" + typeName).str();
@@ -501,8 +506,8 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
return op;
}
-irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
- tblgen::AttrDef &tblgenAttr) {
+static irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+ tblgen::AttrDef &tblgenAttr) {
MLIRContext *ctx = builder.getContext();
StringRef attrName = getAttrName(tblgenAttr);
std::string combined = ("#" + attrName).str();
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index f64bb24..a4a48f0 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1087,14 +1087,15 @@ TEST(IntegerPolyhedronTest, negativeDividends) {
checkDivisionRepresentation(poly1, divisions, denoms);
}
-void expectRationalLexMin(const IntegerPolyhedron &poly,
- ArrayRef<Fraction> min) {
+static void expectRationalLexMin(const IntegerPolyhedron &poly,
+ ArrayRef<Fraction> min) {
auto lexMin = poly.findRationalLexMin();
ASSERT_TRUE(lexMin.isBounded());
EXPECT_EQ(ArrayRef<Fraction>(*lexMin), min);
}
-void expectNoRationalLexMin(OptimumKind kind, const IntegerPolyhedron &poly) {
+static void expectNoRationalLexMin(OptimumKind kind,
+ const IntegerPolyhedron &poly) {
ASSERT_NE(kind, OptimumKind::Bounded)
<< "Use expectRationalLexMin for bounded min";
EXPECT_EQ(poly.findRationalLexMin().getKind(), kind);
@@ -1167,13 +1168,15 @@ TEST(IntegerPolyhedronTest, findRationalLexMin) {
parseIntegerPolyhedron("(x) : (2*x >= 0, -x - 1 >= 0)"));
}
-void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef<int64_t> min) {
+static void expectIntegerLexMin(const IntegerPolyhedron &poly,
+ ArrayRef<int64_t> min) {
MaybeOptimum<SmallVector<DynamicAPInt, 8>> lexMin = poly.findIntegerLexMin();
ASSERT_TRUE(lexMin.isBounded());
EXPECT_EQ(*lexMin, getDynamicAPIntVec(min));
}
-void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) {
+static void expectNoIntegerLexMin(OptimumKind kind,
+ const IntegerPolyhedron &poly) {
ASSERT_NE(kind, OptimumKind::Bounded)
<< "Use expectRationalLexMin for bounded min";
EXPECT_EQ(poly.findRationalLexMin().getKind(), kind);
@@ -1191,7 +1194,7 @@ TEST(IntegerPolyhedronTest, findIntegerLexMin) {
">= 0, -11*z + 5*y - 3*x + 7 >= 0)"));
}
-void expectSymbolicIntegerLexMin(
+static void expectSymbolicIntegerLexMin(
StringRef polyStr,
ArrayRef<std::pair<StringRef, StringRef>> expectedLexminRepr,
ArrayRef<StringRef> expectedUnboundedDomainRepr) {
@@ -1218,8 +1221,9 @@ void expectSymbolicIntegerLexMin(
}
}
-void expectSymbolicIntegerLexMin(
- StringRef polyStr, ArrayRef<std::pair<StringRef, StringRef>> result) {
+static void
+expectSymbolicIntegerLexMin(StringRef polyStr,
+ ArrayRef<std::pair<StringRef, StringRef>> result) {
expectSymbolicIntegerLexMin(polyStr, result, {});
}
@@ -1463,8 +1467,8 @@ TEST(IntegerPolyhedronTest, computeVolume) {
/*trueVolume=*/{}, /*resultBound=*/{});
}
-bool containsPointNoLocal(const IntegerPolyhedron &poly,
- ArrayRef<int64_t> point) {
+static bool containsPointNoLocal(const IntegerPolyhedron &poly,
+ ArrayRef<int64_t> point) {
return poly.containsPointNoLocal(getDynamicAPIntVec(point)).has_value();
}
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index a6ed5c5..9ae90a4 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -714,3 +714,14 @@ TEST(IntegerRelationTest, getVarKindRange) {
}
EXPECT_THAT(actual, ElementsAre(2, 3, 4));
}
+
+TEST(IntegerRelationTest, addLocalModulo) {
+ IntegerRelation rel = parseRelationFromSet("(x) : (x >= 0, 100 - x >= 0)", 1);
+ unsigned result = rel.addLocalModulo({1, 0}, 32); // x % 32
+ rel.convertVarKind(VarKind::Local,
+ result - rel.getVarKindOffset(VarKind::Local),
+ rel.getNumVarKind(VarKind::Local), VarKind::Range);
+ for (unsigned x = 0; x <= 100; ++x) {
+ EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32}));
+ }
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index abc6c707..fe26fc1 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -124,7 +124,7 @@ FOREVERY_BINOP(IMPL_BINOP_PATTERN)
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected:
MergerTestBase(unsigned numTensors, unsigned numLoops)
- : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
+ : merger(numTensors, numLoops, /*maxLvlRank=*/numLoops) {
tensors.reserve(numTensors);
for (unsigned t = 0; t < numTensors; t++)
tensors.push_back(merger.addTensorExp(tid(t)));
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index 887db22..cdeeca2 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -205,7 +205,13 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
};
int64_t shape[] = {k, m};
int64_t shapeAlloc[] = {k + 1, m + 1};
- OwningMemRef<float, 2> a(shape, shapeAlloc, init);
+ // Use a large alignment to stress the case where the memref data/basePtr are
+ // disjoint.
+ int alignment = 8192;
+ OwningMemRef<float, 2> a(shape, shapeAlloc, init, alignment);
+ ASSERT_EQ(
+ (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
+ a->data);
ASSERT_EQ(a->sizes[0], k);
ASSERT_EQ(a->sizes[1], m);
ASSERT_EQ(a->strides[0], m + 1);
@@ -316,4 +322,55 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
ASSERT_EQ(elt, coefficient * count++);
}
+static int initCnt = 0;
+// A helper function that will be called during the JIT's initialization.
+static void initCallback() { initCnt += 1; }
+
+TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(CallbackInGlobalCtor)) {
+ auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
+ ASSERT_TRUE(!!tmBuilderOrError);
+ if (tmBuilderOrError->getTargetTriple().isAArch64()) {
+ GTEST_SKIP() << "Skipping global ctor initialization test on Aarch64 "
+ "because of bug #71963";
+ return;
+ }
+ std::string moduleStr = R"mlir(
+ llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
+ llvm.func @ctor() {
+ func.call @init_callback() : () -> ()
+ llvm.return
+ }
+ func.func private @init_callback() attributes { llvm.emit_c_interface }
+ )mlir";
+
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerBuiltinDialectTranslation(registry);
+ registerLLVMDialectTranslation(registry);
+ MLIRContext context(registry);
+ auto module = parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+ ExecutionEngineOptions jitOptions;
+ auto jitOrError = ExecutionEngine::create(*module, jitOptions);
+ ASSERT_TRUE(!!jitOrError);
+ // validate initialization is not run on construction
+ ASSERT_EQ(initCnt, 0);
+ auto jit = std::move(jitOrError.get());
+ // Define any extra symbols so they're available at initialization.
+ jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+ llvm::orc::SymbolMap symbolMap;
+ symbolMap[interner("_mlir_ciface_init_callback")] = {
+ llvm::orc::ExecutorAddr::fromPtr(initCallback),
+ llvm::JITSymbolFlags::Exported};
+ return symbolMap;
+ });
+ jit->initialize();
+ // validate the side effect of initialization
+ ASSERT_EQ(initCnt, 1);
+ // next initialization should be noop
+ jit->initialize();
+ ASSERT_EQ(initCnt, 1);
+}
+
#endif // _WIN32
diff --git a/mlir/unittests/IR/AttrTypeReplacerTest.cpp b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
index c7b42eb..f17c712 100644
--- a/mlir/unittests/IR/AttrTypeReplacerTest.cpp
+++ b/mlir/unittests/IR/AttrTypeReplacerTest.cpp
@@ -80,9 +80,9 @@ public:
});
}
- Type getFunctionTypeChain(unsigned N) {
+ Type getFunctionTypeChain(unsigned n) {
Type type = b.getIndexType();
- for (unsigned i = 0; i < N; i++)
+ for (unsigned i = 0; i < n; i++)
type = b.getFunctionType({}, type);
return type;
};
@@ -168,9 +168,9 @@ public:
});
}
- Type getFunctionTypeTree(unsigned N) {
+ Type getFunctionTypeTree(unsigned n) {
Type type = b.getIndexType();
- for (unsigned i = 0; i < N; i++)
+ for (unsigned i = 0; i < n; i++)
type = b.getFunctionType(type, type);
return type;
};
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index a46e647..75cd2d6 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_unittest(MLIRIRTests
MemrefLayoutTest.cpp
OperationSupportTest.cpp
PatternMatchTest.cpp
+ RemarkTest.cpp
ShapedTypeTest.cpp
SymbolTableTest.cpp
TypeTest.cpp
@@ -28,3 +29,4 @@ add_mlir_unittest(MLIRIRTests
target_include_directories(MLIRIRTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
mlir_target_link_libraries(MLIRIRTests PRIVATE MLIRIR)
target_link_libraries(MLIRIRTests PRIVATE MLIRTestDialect)
+target_link_libraries(MLIRIRTests PRIVATE MLIRRemarkStreamer)
diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp
new file mode 100644
index 0000000..65e1e08
--- /dev/null
+++ b/mlir/unittests/IR/RemarkTest.cpp
@@ -0,0 +1,315 @@
+//===- RemarkTest.cpp - Remark unit tests -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Remarks.h"
+#include "mlir/Remark/RemarkStreamer.h"
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/LLVMRemarkStreamer.h"
+#include "llvm/Remarks/RemarkFormat.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/YAMLParser.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace llvm;
+using namespace mlir;
+using namespace testing;
+namespace {
+
+TEST(Remark, TestOutputOptimizationRemark) {
+ std::string categoryVectorizer("Vectorizer");
+ std::string categoryRegister("Register");
+ std::string categoryUnroll("Unroll");
+ std::string categoryInliner("Inliner");
+ std::string categoryReroller("Reroller");
+ std::string myPassname1("myPass1");
+ SmallString<64> tmpPathStorage;
+ sys::fs::createUniquePath("remarks-%%%%%%.yaml", tmpPathStorage,
+ /*MakeAbsolute=*/true);
+ std::string yamlFile =
+ std::string(tmpPathStorage.data(), tmpPathStorage.size());
+ ASSERT_FALSE(yamlFile.empty());
+
+ {
+ MLIRContext context;
+ Location loc = UnknownLoc::get(&context);
+
+ context.printOpOnDiagnostic(true);
+ context.printStackTraceOnDiagnostic(true);
+
+ // Setup the remark engine
+ mlir::remark::RemarkCategories cats{/*passed=*/categoryVectorizer,
+ /*missed=*/categoryUnroll,
+ /*analysis=*/categoryRegister,
+ /*failed=*/categoryInliner};
+
+ LogicalResult isEnabled =
+ mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
+ context, yamlFile, llvm::remarks::Format::YAML, cats);
+ ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
+
+ // PASS: something succeeded
+ remark::passed(loc, remark::RemarkOpts::name("Pass1")
+ .category(categoryVectorizer)
+ .subCategory(myPassname1)
+ .function("bar"))
+ << "vectorized loop" << remark::metric("tripCount", 128);
+
+ // ANALYSIS: neutral insight
+ remark::analysis(
+ loc, remark::RemarkOpts::name("Analysis1").category(categoryRegister))
+ << "Kernel uses 168 registers";
+
+ // MISSED: explain why + suggest a fix
+ remark::missed(loc, remark::RemarkOpts::name("Miss1")
+ .category(categoryUnroll)
+ .subCategory(myPassname1))
+ << remark::reason("not profitable at this size")
+ << remark::suggest("increase unroll factor to >=4");
+
+ // FAILURE: action attempted but failed
+ remark::failed(loc, remark::RemarkOpts::name("Failed1")
+ .category(categoryInliner)
+ .subCategory(myPassname1))
+ << remark::reason("failed due to unsupported pattern");
+
+ // FAILURE: Won't show up
+ remark::failed(loc, remark::RemarkOpts::name("Failed2")
+ .category(categoryReroller)
+ .subCategory(myPassname1))
+ << remark::reason("failed due to rerolling pattern");
+ }
+
+ // Read the file
+ auto bufferOrErr = MemoryBuffer::getFile(yamlFile);
+ ASSERT_TRUE(static_cast<bool>(bufferOrErr)) << "Failed to open remarks file";
+ std::string content = bufferOrErr.get()->getBuffer().str();
+
+ EXPECT_THAT(content, HasSubstr("--- !Passed"));
+ EXPECT_THAT(content, HasSubstr("Name: Pass1"));
+ EXPECT_THAT(content, HasSubstr("Pass: 'Vectorizer:myPass1'"));
+ EXPECT_THAT(content, HasSubstr("Function: bar"));
+ EXPECT_THAT(content, HasSubstr("Remark: vectorized loop"));
+ EXPECT_THAT(content, HasSubstr("tripCount: '128'"));
+
+ EXPECT_THAT(content, HasSubstr("--- !Analysis"));
+ EXPECT_THAT(content, HasSubstr("Pass: Register"));
+ EXPECT_THAT(content, HasSubstr("Name: Analysis1"));
+ EXPECT_THAT(content, HasSubstr("Function: '<unknown function>'"));
+ EXPECT_THAT(content, HasSubstr("Remark: Kernel uses 168 registers"));
+
+ EXPECT_THAT(content, HasSubstr("--- !Missed"));
+ EXPECT_THAT(content, HasSubstr("Pass: 'Unroll:myPass1'"));
+ EXPECT_THAT(content, HasSubstr("Name: Miss1"));
+ EXPECT_THAT(content, HasSubstr("Function: '<unknown function>'"));
+ EXPECT_THAT(content,
+ HasSubstr("Reason: not profitable at this size"));
+ EXPECT_THAT(content,
+ HasSubstr("Suggestion: 'increase unroll factor to >=4'"));
+
+ EXPECT_THAT(content, HasSubstr("--- !Failure"));
+ EXPECT_THAT(content, HasSubstr("Pass: 'Inliner:myPass1'"));
+ EXPECT_THAT(content, HasSubstr("Name: Failed1"));
+ EXPECT_THAT(content, HasSubstr("Function: '<unknown function>'"));
+ EXPECT_THAT(content,
+ HasSubstr("Reason: failed due to unsupported pattern"));
+
+ EXPECT_THAT(content, Not(HasSubstr("Failed2")));
+ EXPECT_THAT(content, Not(HasSubstr("Reroller")));
+
+ // Also verify document order to avoid false positives.
+ size_t iPassed = content.find("--- !Passed");
+ size_t iAnalysis = content.find("--- !Analysis");
+ size_t iMissed = content.find("--- !Missed");
+ size_t iFailure = content.find("--- !Failure");
+
+ ASSERT_NE(iPassed, std::string::npos);
+ ASSERT_NE(iAnalysis, std::string::npos);
+ ASSERT_NE(iMissed, std::string::npos);
+ ASSERT_NE(iFailure, std::string::npos);
+
+ EXPECT_LT(iPassed, iAnalysis);
+ EXPECT_LT(iAnalysis, iMissed);
+ EXPECT_LT(iMissed, iFailure);
+}
+
+TEST(Remark, TestNoOutputOptimizationRemark) {
+ const auto *pass1Msg = "My message";
+
+ std::string categoryFailName("myImportantCategory");
+ std::string myPassname1("myPass1");
+ std::string funcName("myFunc");
+ SmallString<64> tmpPathStorage;
+ sys::fs::createUniquePath("remarks-%%%%%%.yaml", tmpPathStorage,
+ /*MakeAbsolute=*/true);
+ std::string yamlFile =
+ std::string(tmpPathStorage.data(), tmpPathStorage.size());
+ ASSERT_FALSE(yamlFile.empty());
+ std::error_code ec =
+ llvm::sys::fs::remove(yamlFile, /*IgnoreNonExisting=*/true);
+ if (ec) {
+ FAIL() << "Failed to remove file " << yamlFile << ": " << ec.message();
+ }
+ {
+ MLIRContext context;
+ Location loc = UnknownLoc::get(&context);
+ remark::failed(loc, remark::RemarkOpts::name("myfail")
+ .category(categoryFailName)
+ .subCategory(myPassname1))
+ << remark::reason(pass1Msg);
+ }
+ // No setup, so no output file should be created
+ // check!
+ bool fileExists = llvm::sys::fs::exists(yamlFile);
+ EXPECT_FALSE(fileExists)
+ << "Expected no YAML file to be created without setupOptimizationRemarks";
+}
+
+TEST(Remark, TestOutputOptimizationRemarkDiagnostic) {
+ std::string categoryVectorizer("Vectorizer");
+ std::string categoryRegister("Register");
+ std::string categoryUnroll("Unroll");
+ std::string myPassname1("myPass1");
+ std::string fName("foo");
+
+ llvm::SmallVector<std::string> seenMsg;
+ {
+ MLIRContext context;
+ Location loc = UnknownLoc::get(&context);
+
+ context.printOpOnDiagnostic(true);
+ context.printStackTraceOnDiagnostic(true);
+
+ // Register a handler that captures the diagnostic.
+ ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
+ seenMsg.push_back(diag.str());
+ return success();
+ });
+
+ // Setup the remark engine
+ mlir::remark::RemarkCategories cats{/*passed=*/categoryVectorizer,
+ /*missed=*/categoryUnroll,
+ /*analysis=*/categoryRegister,
+ /*failed=*/categoryUnroll};
+
+ LogicalResult isEnabled =
+ remark::enableOptimizationRemarks(context, nullptr, cats, true);
+
+ ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
+
+ // PASS: something succeeded
+ remark::passed(loc, remark::RemarkOpts::name("pass1")
+ .category(categoryVectorizer)
+ .function(fName)
+ .subCategory(myPassname1))
+ << "vectorized loop" << remark::metric("tripCount", 128);
+
+ // ANALYSIS: neutral insight
+ remark::analysis(loc, remark::RemarkOpts::name("Analysis1")
+ .category(categoryRegister)
+ .function(fName))
+ << "Kernel uses 168 registers";
+
+ // MISSED: explain why + suggest a fix
+ int target = 128;
+ int tripBad = 4;
+ int threshold = 256;
+
+ remark::missed(loc, {"", categoryUnroll, "unroller2", ""})
+ << remark::reason("tripCount={0} < threshold={1}", tripBad, threshold);
+
+ remark::missed(loc, {"", categoryUnroll, "", ""})
+ << remark::reason("tripCount={0} < threshold={1}", tripBad, threshold)
+ << remark::suggest("increase unroll to {0}", target);
+
+ // FAILURE: action attempted but failed
+ remark::failed(loc, {"", categoryUnroll, "", ""})
+ << remark::reason("failed due to unsupported pattern");
+ }
+ // clang-format off
+ unsigned long expectedSize = 5;
+ ASSERT_EQ(seenMsg.size(), expectedSize);
+ EXPECT_EQ(seenMsg[0], "[Passed] pass1 | Category:Vectorizer:myPass1 | Function=foo | Remark=\"vectorized loop\", tripCount=128");
+ EXPECT_EQ(seenMsg[1], "[Analysis] Analysis1 | Category:Register | Function=foo | Remark=\"Kernel uses 168 registers\"");
+ EXPECT_EQ(seenMsg[2], "[Missed] | Category:Unroll:unroller2 | Reason=\"tripCount=4 < threshold=256\"");
+ EXPECT_EQ(seenMsg[3], "[Missed] | Category:Unroll | Reason=\"tripCount=4 < threshold=256\", Suggestion=\"increase unroll to 128\"");
+ EXPECT_EQ(seenMsg[4], "[Failure] | Category:Unroll | Reason=\"failed due to unsupported pattern\"");
+ // clang-format on
+}
+
+/// Custom remark streamer that prints remarks to stderr.
+class MyCustomStreamer : public remark::detail::MLIRRemarkStreamerBase {
+public:
+ MyCustomStreamer() = default;
+
+ void streamOptimizationRemark(const remark::detail::Remark &remark) override {
+ llvm::errs() << "Custom remark: ";
+ remark.print(llvm::errs(), true);
+ llvm::errs() << "\n";
+ }
+};
+
+TEST(Remark, TestCustomOptimizationRemarkDiagnostic) {
+ testing::internal::CaptureStderr();
+ const auto *pass1Msg = "My message";
+ const auto *pass2Msg = "My another message";
+ const auto *pass3Msg = "Do not show this message";
+
+ std::string categoryLoopunroll("LoopUnroll");
+ std::string categoryInline("Inliner");
+ std::string myPassname1("myPass1");
+ std::string myPassname2("myPass2");
+ std::string funcName("myFunc");
+
+ std::string seenMsg = "";
+
+ {
+ MLIRContext context;
+ Location loc = UnknownLoc::get(&context);
+
+ // Setup the remark engine
+ mlir::remark::RemarkCategories cats{/*passed=*/categoryLoopunroll,
+ /*missed=*/std::nullopt,
+ /*analysis=*/std::nullopt,
+ /*failed=*/categoryLoopunroll};
+
+ LogicalResult isEnabled = remark::enableOptimizationRemarks(
+ context, std::make_unique<MyCustomStreamer>(), cats, true);
+ ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
+
+ // Remark 1: pass, category LoopUnroll
+ remark::passed(loc, {"", categoryLoopunroll, myPassname1, ""}) << pass1Msg;
+ // Remark 2: failure, category LoopUnroll
+ remark::failed(loc, {"", categoryLoopunroll, myPassname2, ""})
+ << remark::reason(pass2Msg);
+ // Remark 3: pass, category Inline (should not be printed)
+ remark::passed(loc, {"", categoryInline, myPassname1, ""}) << pass3Msg;
+ }
+
+ llvm::errs().flush();
+ std::string errOut = ::testing::internal::GetCapturedStderr();
+
+ // Expect exactly two "Custom remark:" lines.
+ auto first = errOut.find("Custom remark:");
+ EXPECT_NE(first, std::string::npos);
+ auto second = errOut.find("Custom remark:", first + 1);
+ EXPECT_NE(second, std::string::npos);
+ auto third = errOut.find("Custom remark:", second + 1);
+ EXPECT_EQ(third, std::string::npos);
+
+ // Containment checks for messages.
+ EXPECT_NE(errOut.find(pass1Msg), std::string::npos); // printed
+ EXPECT_NE(errOut.find(pass2Msg), std::string::npos); // printed
+ EXPECT_EQ(errOut.find(pass3Msg), std::string::npos); // filtered out
+}
+} // namespace
diff --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
index 65ea4ee..e4363f9 100644
--- a/mlir/unittests/Rewrite/PatternBenefit.cpp
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -66,12 +66,7 @@ TEST(PatternBenefitTest, BenefitOrder) {
PatternApplicator pa(frozenPatterns);
pa.applyDefaultCostModel();
- class MyPatternRewriter : public PatternRewriter {
- public:
- MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
- };
-
- MyPatternRewriter rewriter(&context);
+ PatternRewriter rewriter(&context);
(void)pa.matchAndRewrite(*module, rewriter);
EXPECT_TRUE(called1);
diff --git a/mlir/utils/clang-tidy/apply-clang-tidy.sh b/mlir/utils/clang-tidy/apply-clang-tidy.sh
index a359218..f0973dd 100755
--- a/mlir/utils/clang-tidy/apply-clang-tidy.sh
+++ b/mlir/utils/clang-tidy/apply-clang-tidy.sh
@@ -89,7 +89,8 @@ find $SRCS | grep ".cpp$" | sort | while read file ; do
echo "-----------------------------------"
echo "-- Apply check $check on file $file"
- echo "$TIMING_TIDY $CLANG_TIDY -p $BUILD_DIR $file --checks="-*,$check" -fix -fix-errors"
+ COMMAND=$(echo "$TIMING_TIDY $CLANG_TIDY -p $BUILD_DIR $file --checks="-*,$check" -fix -fix-errors")
+ echo $COMMAND
{ $TIMING_TIDY $CLANG_TIDY -p $BUILD_DIR $file --checks="-*,$check" -fix -fix-errors ; } 2>&1
git clang-format -f
if [[ $(git diff --stat) == '' ]]; then
@@ -101,16 +102,19 @@ find $SRCS | grep ".cpp$" | sort | while read file ; do
# Clang-tidy sometimes update files in the build directory, erase the .inc file generate by tablegen
# to force them to be regenerated now.
find $BUILD_DIR/tools/mlir/ | grep '\.inc' | while read file ; do rm $file ; done
- ninja -C $BUILD_DIR check-mlir > ${REJECT_DIR}/ninja.${check}.$(basename $file).log 2>&1
+ echo $COMMAND > ${REJECT_DIR}/ninja.${check}.$(basename $file).log
+ ninja -C $BUILD_DIR --quiet check-mlir >> ${REJECT_DIR}/${check}.$(basename $file).ninja.log 2>&1
if [[ $? != 0 ]] ; then
echo "check-mlir failed! (see ninja.${check}.${file}.log)"
[[ ! -z "$REJECT_DIR" ]] && git diff > "${REJECT_DIR}/${check}_$(basename ${file}).reject.diff"
continue
fi
+ rm -f ${REJECT_DIR}/ninja.${check}.$(basename $file).log
+
echo "-----------------------------------"
echo "-- Success, commit changes for check $check on file $file"
git clang-format -f
- git commit -a -m "Apply clang-tidy fixes for $check in $(basename $file) (NFC)"
+ git commit -a -m "[MLIR] Apply clang-tidy fixes for $check in $(basename $file) (NFC)"
done
done
diff --git a/mlir/utils/tree-sitter-mlir/dialect/linalg.js b/mlir/utils/tree-sitter-mlir/dialect/linalg.js
index ddde92b..f465808 100644
--- a/mlir/utils/tree-sitter-mlir/dialect/linalg.js
+++ b/mlir/utils/tree-sitter-mlir/dialect/linalg.js
@@ -4,7 +4,6 @@ module.exports = {
linalg_dialect : $ => prec.right(choice(
seq(choice(
'linalg.batch_matmul',
- 'linalg.batch_matmul_transpose_b',
'linalg.batch_matvec',
'linalg.batch_reduce_matmul', 'linalg.broadcast',
'linalg.conv_1d_ncw_fcw', 'linalg.conv_1d_nwc_wcf',
@@ -27,7 +26,6 @@ module.exports = {
'linalg.dot', 'linalg.elemwise_binary',
'linalg.elemwise_unary', 'linalg.fill',
'linalg.fill_rng_2d', 'linalg.matmul',
- 'linalg.matmul_transpose_b',
'linalg.matmul_unsigned', 'linalg.matvec',
'linalg.mmt4d', 'linalg.pooling_nchw_max',
'linalg.pooling_nchw_sum',
diff --git a/mlir/utils/tree-sitter-mlir/queries/highlights.scm b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
index 4cbea7b..59e280b 100644
--- a/mlir/utils/tree-sitter-mlir/queries/highlights.scm
+++ b/mlir/utils/tree-sitter-mlir/queries/highlights.scm
@@ -213,7 +213,6 @@
"bufferization.to_tensor"
"linalg.batch_matmul"
- "linalg.batch_matmul_transpose_b"
"linalg.batch_matvec"
"linalg.batch_reduce_matmul"
"linalg.broadcast"
@@ -244,7 +243,6 @@
"linalg.fill"
"linalg.fill_rng_2d"
"linalg.matmul"
- "linalg.matmul_transpose_b"
"linalg.matmul_unsigned"
"linalg.matvec"
"linalg.mmt4d"