aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/docs/Canonicalization.md2
-rw-r--r--mlir/docs/Dialects/Shard.md6
-rw-r--r--mlir/include/mlir-c/Rewrite.h2
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h54
-rw-r--r--mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h8
-rw-r--r--mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h27
-rw-r--r--mlir/include/mlir/Conversion/Passes.h1
-rw-r--r--mlir/include/mlir/Conversion/Passes.td32
-rw-r--r--mlir/include/mlir/Dialect/Affine/LoopUtils.h2
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td5
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td40
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td26
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRef.h1
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td2
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACC.h4
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td43
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td94
-rw-r--r--mlir/include/mlir/Dialect/Shard/IR/ShardOps.td119
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h48
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc940
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td33
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h13
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td7
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td1
-rw-r--r--mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td175
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td6
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td65
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td62
-rw-r--r--mlir/include/mlir/IR/Remarks.h140
-rw-r--r--mlir/include/mlir/Interfaces/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Interfaces/InferIntRangeInterface.h12
-rw-r--r--mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h145
-rw-r--r--mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td45
-rw-r--r--mlir/include/mlir/Remark/RemarkStreamer.h1
-rw-r--r--mlir/include/mlir/TableGen/CodeGenHelpers.h2
-rw-r--r--mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h71
-rw-r--r--mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h9
-rw-r--r--mlir/lib/Analysis/CMakeLists.txt2
-rw-r--r--mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp127
-rw-r--r--mlir/lib/Analysis/FlatLinearValueConstraints.cpp9
-rw-r--r--mlir/lib/Analysis/Presburger/Simplex.cpp2
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp56
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp2
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp2
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp2
-rw-r--r--mlir/lib/Conversion/MathToROCDL/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp76
-rw-r--r--mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp4
-rw-r--r--mlir/lib/Conversion/MathToXeVM/CMakeLists.txt22
-rw-r--r--mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp167
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp7
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp200
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp8
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp150
-rw-r--r--mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp5
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp75
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp14
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp6
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp74
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp10
-rw-r--r--mlir/lib/Dialect/MemRef/IR/CMakeLists.txt3
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp95
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp2
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp281
-rw-r--r--mlir/lib/Dialect/Shard/IR/ShardOps.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp4
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp4
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp7
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp4
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp74
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformTypes.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp105
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp62
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp52
-rw-r--r--mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp2
-rw-r--r--mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp141
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp146
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp122
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp3
-rw-r--r--mlir/lib/IR/Diagnostics.cpp6
-rw-r--r--mlir/lib/IR/MLIRContext.cpp17
-rw-r--r--mlir/lib/IR/Remarks.cpp57
-rw-r--r--mlir/lib/Interfaces/CMakeLists.txt16
-rw-r--r--mlir/lib/Interfaces/InferIntRangeInterface.cpp19
-rw-r--r--mlir/lib/Interfaces/InferStridedMetadataInterface.cpp36
-rw-r--r--mlir/lib/Remark/RemarkStreamer.cpp4
-rw-r--r--mlir/lib/TableGen/CodeGenHelpers.cpp4
-rw-r--r--mlir/lib/Target/LLVMIR/DebugImporter.cpp4
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp3
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp2
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp5
-rw-r--r--mlir/lib/Target/Wasm/TranslateFromWasm.cpp461
-rw-r--r--mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp2
-rw-r--r--mlir/lib/Tools/PDLL/Parser/Parser.cpp2
-rw-r--r--mlir/lib/Tools/mlir-opt/MlirOptMain.cpp37
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp2
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp11
-rw-r--r--mlir/python/CMakeLists.txt10
-rw-r--r--mlir/python/mlir/dialects/OpenACCOps.td14
-rw-r--r--mlir/python/mlir/dialects/gpu/__init__.py146
-rw-r--r--mlir/python/mlir/dialects/openacc.py5
-rw-r--r--mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir67
-rw-r--r--mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir76
-rw-r--r--mlir/test/Conversion/MathToXeVM/lit.local.cfg7
-rw-r--r--mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir155
-rw-r--r--mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir119
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir12
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/dpas.mlir8
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir201
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir29
-rw-r--r--mlir/test/Dialect/Affine/canonicalize.mlir130
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir6
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir32
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir38
-rw-r--r--mlir/test/Dialect/LLVMIR/canonicalize.mlir11
-rw-r--r--mlir/test/Dialect/LLVMIR/rocdl.mlir14
-rw-r--r--mlir/test/Dialect/Linalg/match-ops-interpreter.mlir14
-rw-r--r--mlir/test/Dialect/Linalg/one-shot-bufferize.mlir16
-rw-r--r--mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir181
-rw-r--r--mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir133
-rw-r--r--mlir/test/Dialect/MemRef/canonicalize.mlir30
-rw-r--r--mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir4
-rw-r--r--mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir102
-rw-r--r--mlir/test/Dialect/OpenACC/recipe-populate-private.mlir82
-rw-r--r--mlir/test/Dialect/SCF/one-shot-bufferize.mlir12
-rw-r--r--mlir/test/Dialect/Tensor/one-shot-bufferize.mlir41
-rw-r--r--mlir/test/Dialect/Tosa/tosa-attach-target.mlir8
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir21
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir20
-rw-r--r--mlir/test/Dialect/Vector/canonicalize/vector-step.mlir311
-rw-r--r--mlir/test/Dialect/Vector/vector-unroll-options.mlir68
-rw-r--r--mlir/test/Dialect/Vector/vector-warp-distribute.mlir19
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/global.mlir4
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/if.mlir8
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/import.mlir8
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/local.mlir12
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir8
-rw-r--r--mlir/test/Dialect/WasmSSA/custom_parser/table.mlir6
-rw-r--r--mlir/test/Dialect/WasmSSA/extend-invalid.mlir4
-rw-r--r--mlir/test/Dialect/WasmSSA/global-invalid.mlir12
-rw-r--r--mlir/test/Dialect/WasmSSA/locals-invalid.mlir4
-rw-r--r--mlir/test/Dialect/XeGPU/invalid.mlir36
-rw-r--r--mlir/test/Dialect/XeGPU/ops.mlir60
-rw-r--r--mlir/test/Integration/GPU/SPIRV/simple_add.mlir11
-rw-r--r--mlir/test/Pass/remark-final.mlir17
-rw-r--r--mlir/test/Target/LLVMIR/Import/debug-info.ll3
-rw-r--r--mlir/test/Target/LLVMIR/Import/function-attributes.ll6
-rw-r--r--mlir/test/Target/LLVMIR/llvmir.mlir11
-rw-r--r--mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir12
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir-invalid.mlir30
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir.mlir4
-rw-r--r--mlir/test/Target/LLVMIR/rocdl.mlir14
-rw-r--r--mlir/test/Target/Wasm/abs.mlir4
-rw-r--r--mlir/test/Target/Wasm/add_div.mlir40
-rw-r--r--mlir/test/Target/Wasm/and.mlir4
-rw-r--r--mlir/test/Target/Wasm/block.mlir16
-rw-r--r--mlir/test/Target/Wasm/block_complete_type.mlir24
-rw-r--r--mlir/test/Target/Wasm/block_value_type.mlir19
-rw-r--r--mlir/test/Target/Wasm/branch_if.mlir29
-rw-r--r--mlir/test/Target/Wasm/call.mlir17
-rw-r--r--mlir/test/Target/Wasm/clz.mlir4
-rw-r--r--mlir/test/Target/Wasm/comparison_ops.mlir269
-rw-r--r--mlir/test/Target/Wasm/const.mlir8
-rw-r--r--mlir/test/Target/Wasm/convert.mlir85
-rw-r--r--mlir/test/Target/Wasm/copysign.mlir4
-rw-r--r--mlir/test/Target/Wasm/ctz.mlir4
-rw-r--r--mlir/test/Target/Wasm/demote.mlir15
-rw-r--r--mlir/test/Target/Wasm/div.mlir20
-rw-r--r--mlir/test/Target/Wasm/double_nested_loop.mlir63
-rw-r--r--mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir53
-rw-r--r--mlir/test/Target/Wasm/eq.mlir56
-rw-r--r--mlir/test/Target/Wasm/eqz.mlir21
-rw-r--r--mlir/test/Target/Wasm/extend.mlir69
-rw-r--r--mlir/test/Target/Wasm/global.mlir16
-rw-r--r--mlir/test/Target/Wasm/if.mlir112
-rw-r--r--mlir/test/Target/Wasm/import.mlir12
-rw-r--r--mlir/test/Target/Wasm/inputs/add_div.yaml.wasm50
-rw-r--r--mlir/test/Target/Wasm/inputs/block.yaml.wasm22
-rw-r--r--mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm23
-rw-r--r--mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm18
-rw-r--r--mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm18
-rw-r--r--mlir/test/Target/Wasm/inputs/call.yaml.wasm26
-rw-r--r--mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm88
-rw-r--r--mlir/test/Target/Wasm/inputs/convert.yaml.wasm69
-rw-r--r--mlir/test/Target/Wasm/inputs/demote.yaml.wasm18
-rw-r--r--mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm19
-rw-r--r--mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm21
-rw-r--r--mlir/test/Target/Wasm/inputs/eq.yaml.wasm27
-rw-r--r--mlir/test/Target/Wasm/inputs/eqz.yaml.wasm29
-rw-r--r--mlir/test/Target/Wasm/inputs/extend.yaml.wasm40
-rw-r--r--mlir/test/Target/Wasm/inputs/if.yaml.wasm25
-rw-r--r--mlir/test/Target/Wasm/inputs/loop.yaml.wasm17
-rw-r--r--mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm20
-rw-r--r--mlir/test/Target/Wasm/inputs/ne.yaml.wasm27
-rw-r--r--mlir/test/Target/Wasm/inputs/promote.yaml.wasm18
-rw-r--r--mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm53
-rw-r--r--mlir/test/Target/Wasm/inputs/rounding.yaml.wasm37
-rw-r--r--mlir/test/Target/Wasm/inputs/wrap.yaml.wasm24
-rw-r--r--mlir/test/Target/Wasm/invalid_block_type_index.yaml28
-rw-r--r--mlir/test/Target/Wasm/local.mlir6
-rw-r--r--mlir/test/Target/Wasm/loop.mlir17
-rw-r--r--mlir/test/Target/Wasm/loop_with_inst.mlir33
-rw-r--r--mlir/test/Target/Wasm/max.mlir4
-rw-r--r--mlir/test/Target/Wasm/memory_min_eq_max.mlir2
-rw-r--r--mlir/test/Target/Wasm/memory_min_max.mlir2
-rw-r--r--mlir/test/Target/Wasm/memory_min_no_max.mlir2
-rw-r--r--mlir/test/Target/Wasm/min.mlir4
-rw-r--r--mlir/test/Target/Wasm/ne.mlir52
-rw-r--r--mlir/test/Target/Wasm/neg.mlir4
-rw-r--r--mlir/test/Target/Wasm/or.mlir4
-rw-r--r--mlir/test/Target/Wasm/popcnt.mlir4
-rw-r--r--mlir/test/Target/Wasm/promote.mlir14
-rw-r--r--mlir/test/Target/Wasm/reinterpret.mlir46
-rw-r--r--mlir/test/Target/Wasm/rem.mlir8
-rw-r--r--mlir/test/Target/Wasm/rotl.mlir4
-rw-r--r--mlir/test/Target/Wasm/rotr.mlir4
-rw-r--r--mlir/test/Target/Wasm/rounding.mlir50
-rw-r--r--mlir/test/Target/Wasm/shl.mlir4
-rw-r--r--mlir/test/Target/Wasm/shr_s.mlir4
-rw-r--r--mlir/test/Target/Wasm/shr_u.mlir4
-rw-r--r--mlir/test/Target/Wasm/sqrt.mlir4
-rw-r--r--mlir/test/Target/Wasm/sub.mlir8
-rw-r--r--mlir/test/Target/Wasm/wrap.mlir15
-rw-r--r--mlir/test/Target/Wasm/xor.mlir4
-rw-r--r--mlir/test/lib/Analysis/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp86
-rw-r--r--mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp26
-rw-r--r--mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp6
-rw-r--r--mlir/test/lib/Dialect/OpenACC/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp6
-rw-r--r--mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp8
-rw-r--r--mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp110
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttrDefs.td17
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttributes.cpp18
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttributes.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.td5
-rw-r--r--mlir/test/lib/Dialect/Test/TestOpDefs.cpp44
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td15
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp6
-rw-r--r--mlir/test/lib/Pass/TestRemarksPass.cpp7
-rw-r--r--mlir/test/mlir-pdll/CodeGen/CPP/general.pdll2
-rw-r--r--mlir/test/mlir-tblgen/cpp-class-comments.td10
-rw-r--r--mlir/test/python/dialects/gpu/dialect.py93
-rw-r--r--mlir/test/python/dialects/openacc.py171
-rw-r--r--mlir/test/python/ir/operation.py28
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
-rw-r--r--mlir/tools/mlir-pdll/mlir-pdll.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/EnumsGen.cpp21
-rw-r--r--mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp13
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/OpInterfacesGen.cpp38
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp35
-rw-r--r--mlir/unittests/Dialect/SparseTensor/MergerTest.cpp3
-rw-r--r--mlir/unittests/IR/RemarkTest.cpp80
-rwxr-xr-xmlir/utils/generate-test-checks.py48
262 files changed, 8974 insertions, 1412 deletions
diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
index 686e500..2622c08 100644
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -55,7 +55,7 @@ Some important things to think about w.r.t. canonicalization patterns:
* It is always good to eliminate operations entirely when possible, e.g. by
folding known identities (like "x + 0 = x").
-* Pattens with expensive running time (i.e. have O(n) complexity) or
+* Patterns with expensive running time (i.e. have O(n) complexity) or
complicated cost models don't belong to canonicalization: since the
algorithm is executed iteratively until fixed-point we want patterns that
execute quickly (in particular their matching phase).
diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
index eb6ff61..573e888 100644
--- a/mlir/docs/Dialects/Shard.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -27,9 +27,9 @@ the tensor is sharded - not specified manually.
### Device Groups
-Each collective operation runs within a group of devices. You define groups
-using the `grid` and `grid_axes` attributes, which describe how to slice the
-full device grid into smaller groups.
+Collective operations run within groups of devices, which are defined
+using the `grid` and `grid_axes` attributes. These describe
+how the full device grid is sliced into smaller groups.
Devices that have the same coordinates *outside* the listed `grid_axes` belong
to the same group.
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84..fe42a20 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -352,7 +352,7 @@ typedef struct {
/// Create a rewrite pattern that matches the operation
/// with the given rootName, corresponding to mlir::OpRewritePattern.
-MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames);
diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
new file mode 100644
index 0000000..72ac247
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
@@ -0,0 +1,54 @@
+//===- StridedMetadataRangeAnalysis.h - Strided range analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// This lattice element represents the strided metadata of an SSA value.
+class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
+public:
+ using Lattice::Lattice;
+};
+
+/// Strided metadata range analysis determines the strided metadata ranges of
+/// SSA values using operations that define `InferStridedMetadataInterface`.
+///
+/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
+/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
+/// loaded in the same solver context.
+class StridedMetadataRangeAnalysis
+ : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
+public:
+ StridedMetadataRangeAnalysis(DataFlowSolver &solver,
+ int32_t indexBitwidth = 64);
+
+ /// At an entry point, we cannot reason about strided metadata ranges unless
+ /// the type also encodes the data. For example, a memref with static layout.
+ void setToEntryState(StridedMetadataRangeLattice *lattice) override;
+
+ /// Visit an operation. Invoke the transfer function on each operation that
+ /// implements `InferStridedMetadataInterface`.
+ LogicalResult
+ visitOperation(Operation *op,
+ ArrayRef<const StridedMetadataRangeLattice *> operands,
+ ArrayRef<StridedMetadataRangeLattice *> results) override;
+
+private:
+ /// Index bitwidth to use when operating with the int-ranges.
+ int32_t indexBitwidth = 64;
+};
+} // namespace dataflow
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e79..60f1888 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>
@@ -19,8 +20,11 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+/// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`,
+/// none of the chipset-dependent patterns are added.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000..91d3c92
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,27 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b2..40d866e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3c18ecc..9f76f5d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.
+
+ The chipset option specifies the target AMDGPU architecture. If the chipset
+ is empty, none of the chipset-dependent patterns are added, and the pass
+ will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
@@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
+ let options = [Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"\"",
+ "Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
@@ -797,6 +804,31 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
}
//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+ let summary =
+ "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+ let description = [{
+ This pass converts supported math ops marked with the `afn` fastmath flag
+ to function calls for OpenCL `native_` math intrinsics. These intrinsics
+ are typically mapped directly to native device instructions, often resulting
+ in better performance. However, the precision/error of these intrinsics
+ is implementation-defined, and thus math ops are only converted when they
+ have the `afn` fastmath flag enabled.
+ }];
+ let options = [Option<
+ "convertArith", "convert-arith", "bool", /*default=*/"true",
+ "Convert supported Arith ops (e.g. arith.divf) as well.">];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "xevm::XeVMDialect",
+ "LLVM::LLVMDialect",
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 9b59af7..830c394 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -61,7 +61,7 @@ LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
-bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops);
+[[maybe_unused]] bool isPerfectlyNested(ArrayRef<AffineForOp> loops);
/// Get perfectly nested sequence of loops starting at root of loop nest
/// (the first op being another AffineFor, and the second op - a terminator).
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 9753dca..d0811a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
custom<ShuffleType>(ref(type($v1)), type($res), ref($mask))
}];
+ let hasFolder = 1;
let hasVerifier = 1;
string llvmInstName = "ShuffleVector";
@@ -1985,6 +1986,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<StrAttr>:$instrument_function_exit,
OptionalAttr<UnitAttr>:$no_inline,
OptionalAttr<UnitAttr>:$always_inline,
+ OptionalAttr<UnitAttr>:$inline_hint,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
OptionalAttr<UnitAttr>:$optimize_none,
@@ -2037,6 +2039,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
/// Returns true if the `always_inline` attribute is set, false otherwise.
bool isAlwaysInline() { return bool(getAlwaysInlineAttr()); }
+ /// Returns true if the `inline_hint` attribute is set, false otherwise.
+ bool isInlineHint() { return bool(getInlineHintAttr()); }
+
/// Returns true if the `optimize_none` attribute is set, false otherwise.
bool isOptimizeNone() { return bool(getOptimizeNoneAttr()); }
}];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 89fbeb7..d959464 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
+ let hasVerifier = 1;
// Backwards-compatibility builder for an unspecified range.
let builders = [
@@ -279,6 +280,11 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
SetIntRangeFn setResultRanges) {
nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
}
+
+ // Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
+ ::llvm::LogicalResult $cppClass::verify() {
+ return verifyConstantRangeAttr(getOperation(), getRange());
+ }
}];
}
@@ -1655,6 +1661,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+ let summary = "Convert a pair of float inputs to f4x2";
+ let description = [{
+ This Op converts each of the given float inputs to the specified fp4 type.
+ The result `dst` is returned as an i8 type where the converted values are
+ packed such that the value converted from `a` is stored in the upper 4 bits
+ of `dst` and the value converted from `b` is stored in the lower 4 bits of
+ `dst`.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let results = (outs I8:$dst);
+ let arguments = (ins F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6925cec..68f31e6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0],
let assemblyFormat = "$count attr-dict";
}
+def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>,
+ Arguments<(ins I16Attr:$count)> {
+ let summary = "Wait until ASYNCCNT is less than or equal to `count`";
+ let description = [{
+ Wait for the ASYNCCNT counter to be less than or equal to `count`
+ before continuing.
+
+ Available on gfx1250+.
+ }];
+ let results = (outs);
+ let assemblyFormat = "$count attr-dict";
+}
+
+def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>,
+ Arguments<(ins I16Attr:$count)> {
+ let summary = "Wait until TENSORCNT is less than or equal to `count`";
+ let description = [{
+ Wait for the TENSORCNT counter to be less than or equal to `count`
+ before continuing.
+
+ Available on gfx1250+.
+ }];
+ let results = (outs);
+ let assemblyFormat = "$count attr-dict";
+}
+
def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>,
Arguments<(ins I16Attr:$priority)> {
let assemblyFormat = "$priority attr-dict";
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 30f33ed..69447f7 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 89bd0f1..b39207f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/InferStridedMetadataInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -2085,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 8f87235..b8aa497 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -183,6 +183,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() {
return StringLiteral("acc.routine_info");
}
+static constexpr StringLiteral getVarNameAttrName() {
+ return VarNameAttr::name;
+}
+
static constexpr StringLiteral getCombinedConstructsAttrName() {
return CombinedConstructsTypeAttr::name;
}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 77e833f..1eaa21b4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -415,6 +415,13 @@ def OpenACC_ConstructResource : Resource<"::mlir::acc::ConstructResource">;
// Define a resource for the OpenACC current device setting.
def OpenACC_CurrentDeviceIdResource : Resource<"::mlir::acc::CurrentDeviceIdResource">;
+// Attribute for saving variable names - this can be attached to non-acc-dialect
+// operations in order to ensure the name is preserved.
+def OpenACC_VarNameAttr : OpenACC_Attr<"VarName", "var_name"> {
+ let parameters = (ins StringRefParameter<"">:$name);
+ let assemblyFormat = "`<` $name `>`";
+}
+
// Used for data specification in data clauses (2.7.1).
// Either (or both) extent and upperbound must be specified.
def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
@@ -1316,6 +1323,24 @@ def OpenACC_PrivateRecipeOp
}];
let hasRegionVerifier = 1;
+
+ let extraClassDeclaration = [{
+ /// Creates a PrivateRecipeOp and populates its regions based on the
+ /// variable type as long as the type implements MappableType or
+ /// PointerLikeType interface. If a type implements both, the MappableType
+ /// API will be preferred. Returns std::nullopt if the recipe cannot be
+ /// created or populated. The builder's current insertion point will be used
+ /// and it must be a valid place for this operation to be inserted. The
+ /// `recipeName` must be a unique name to prevent "redefinition of symbol"
+ /// IR errors.
+ static std::optional<PrivateRecipeOp> createAndPopulate(
+ ::mlir::OpBuilder &builder,
+ ::mlir::Location loc,
+ ::llvm::StringRef recipeName,
+ ::mlir::Type varType,
+ ::llvm::StringRef varName = "",
+ ::mlir::ValueRange bounds = {});
+ }];
}
//===----------------------------------------------------------------------===//
@@ -1410,6 +1435,24 @@ def OpenACC_FirstprivateRecipeOp
}];
let hasRegionVerifier = 1;
+
+ let extraClassDeclaration = [{
+ /// Creates a FirstprivateRecipeOp and populates its regions based on the
+ /// variable type as long as the type implements MappableType or
+ /// PointerLikeType interface. If a type implements both, the MappableType
+ /// API will be preferred. Returns std::nullopt if the recipe cannot be
+ /// created or populated. The builder's current insertion point will be used
+ /// and it must be a valid place for this operation to be inserted. The
+ /// `recipeName` must be a unique name to prevent "redefinition of symbol"
+ /// IR errors.
+ static std::optional<FirstprivateRecipeOp> createAndPopulate(
+ ::mlir::OpBuilder &builder,
+ ::mlir::Location loc,
+ ::llvm::StringRef recipeName,
+ ::mlir::Type varType,
+ ::llvm::StringRef varName = "",
+ ::mlir::ValueRange bounds = {});
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 0d16255..93e9e3d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -73,17 +73,31 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
InterfaceMethod<
/*description=*/[{
Generates allocation operations for the pointer-like type. It will create
- an allocate that produces memory space for an instance of the current type.
+ an allocate operation that produces memory space for an instance of the
+ current type.
The `varName` parameter is optional and can be used to provide a name
- for the allocated variable. If the current type is represented
- in a way that it does not capture the pointee type, `varType` must be
- passed in to provide the necessary type information.
+ for the allocated variable. When provided, it must be used by the
+ implementation; and if the implementing dialect does not have its own
+ way to save it, the discardable `acc.var_name` attribute from the acc
+ dialect will be used.
+
+ If the current type is represented in a way that it does not capture
+ the pointee type, `varType` must be passed in to provide the necessary
+ type information.
The `originalVar` parameter is optional but enables support for dynamic
types (e.g., dynamic memrefs). When provided, implementations can extract
runtime dimension information from the original variable to create
- allocations with matching dynamic sizes.
+ allocations with matching dynamic sizes. When generating recipe bodies,
+ `originalVar` should be the block argument representing the original
+ variable in the recipe region.
+
+ The `needsFree` output parameter indicates whether the allocated memory
+ requires explicit deallocation. Implementations should set this to true
+ for heap allocations that need a matching deallocation operation (e.g.,
+ alloc) and false for stack-based allocations (e.g., alloca). During
+ recipe generation, this determines whether a destroy region is created.
Returns a Value representing the result of the allocation. If no value
is returned, it means the allocation was not successfully generated.
@@ -94,7 +108,8 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
"::mlir::Location":$loc,
"::llvm::StringRef":$varName,
"::mlir::Type":$varType,
- "::mlir::Value":$originalVar),
+ "::mlir::Value":$originalVar,
+ "bool &":$needsFree),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
@@ -102,23 +117,34 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
>,
InterfaceMethod<
/*description=*/[{
- Generates deallocation operations for the pointer-like type. It deallocates
- the instance provided.
+ Generates deallocation operations for the pointer-like type.
- The `varPtr` parameter is required and must represent an instance that was
- previously allocated. If the current type is represented in a way that it
- does not capture the pointee type, `varType` must be passed in to provide
- the necessary type information. Nothing is generated in case the allocate
- is `alloca`-like.
+ The `varToFree` parameter is required and must represent an instance
+ that was previously allocated. When generating recipe bodies, this
+ should be the block argument representing the private variable in the
+ destroy region.
+
+ The `allocRes` parameter is optional and provides the result of the
+ corresponding allocation from the init region. This allows implementations
+ to inspect the allocation operation to determine the appropriate
+ deallocation strategy. This is necessary because in recipe generation,
+ the allocation and deallocation occur in separate regions. Dialects that
+ use only one allocation type or can determine deallocation from type
+ information alone may ignore this parameter.
+
+ The `varType` parameter must be provided if the current type does not
+ capture the pointee type information. No deallocation is generated for
+ stack-based allocations (e.g., alloca).
- Returns true if deallocation was successfully generated or successfully
- deemed as not needed to be generated, false otherwise.
+ Returns true if deallocation was successfully generated or determined to
+ be unnecessary, false otherwise.
}],
/*retTy=*/"bool",
/*methodName=*/"genFree",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
"::mlir::Location":$loc,
- "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr,
+ "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varToFree,
+ "::mlir::Value":$allocRes,
"::mlir::Type":$varType),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -274,6 +300,14 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
The `initVal` can be empty - it is primarily needed for reductions
to ensure the variable is also initialized with appropriate value.
+ The `needsDestroy` out-parameter is set by implementations to indicate
+ that destruction code must be generated after the returned private
+ variable usages, typically in the destroy region of recipe operations
+ (for example, when heap allocations or temporaries requiring cleanup
+ are created during initialization). When `needsDestroy` is set, callers
+ should invoke `generatePrivateDestroy` in the recipe's destroy region
+ with the privatized value returned by this method.
+
If the return value is empty, it means that recipe body was not
successfully generated.
}],
@@ -284,12 +318,38 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
"::mlir::TypedValue<::mlir::acc::MappableType>":$var,
"::llvm::StringRef":$varName,
"::mlir::ValueRange":$extents,
- "::mlir::Value":$initVal),
+ "::mlir::Value":$initVal,
+ "bool &":$needsDestroy),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
}]
>,
+ InterfaceMethod<
+ /*description=*/[{
+ Generates destruction operations for a privatized value previously
+ produced by `generatePrivateInit`. This is typically inserted in a
+ recipe's destroy region, after all uses of the privatized value.
+
+ The `privatized` value is the SSA value yielded by the init region
+ (and passed as the privatized argument to the destroy region).
+ Implementations should free heap-allocated storage or perform any
+ cleanup required for the given type. If no destruction is required,
+ this function should be a no-op and return `true`.
+
+ Returns true if destruction was successfully generated or deemed not
+ necessary, false otherwise.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"generatePrivateDestroy",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "::mlir::Location":$loc,
+ "::mlir::Value":$privatized),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >,
];
}
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 29b384f..5e68f75e 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -174,7 +174,7 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
```
The above returns two indices, `633` and `693`, which correspond to the
index of the previous process `(1, 1, 3)`, and the next process
- `(1, 3, 3) along the split axis `1`.
+ `(1, 3, 3)` along the split axis `1`.
A negative value is returned if there is no neighbor in the respective
direction along the given `split_axes`.
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
]> {
let summary = "All-gather over a device grid.";
let description = [{
- Gathers along the `gather_axis` tensor axis.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and replicates the result across all devices
+ in the group.
Example:
```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device grid.";
let description = [{
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ Each device in a group receives a replicated copy of the reduction result.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
- let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+ let summary = "All-slice over a device grid.";
let description = [{
- Slice along the `slice_axis` tensor axis.
- This operation can be thought of as the inverse of all-gather.
- Technically, it is not required that all processes have the same input tensor.
- Each process will slice a piece of its local tensor based on its in-group device index.
- The operation does not communicate data between devices.
+ Within each device group defined by `grid_axes`, slices the input tensor along
+ the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+ the input data is replicated along the `slice_axis`.
+ Each process simply crops its local data to the slice corresponding to its
+ in-group device index.
+ Note: `AllSliceOp` does not involve any communication between devices, and
+ devices within a group are not required to hold replicated input data.
Example:
```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
```
Result:
```
- gather tensor
+ slice tensor
axis 1
------------>
+-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device grid.";
let description = [{
- Performs an all-to-all on tensor pieces split along `split_axis`.
- The resulting pieces are concatenated along `concat_axis` on ech device.
+ Each participant logically splits its input along `split_axis`,
+ then scatters the resulting pieces across the group defined by `grid_axes`.
+ After receiving data pieces from other participants' scatters,
+ it concatenates them along `concat_axis` to produce the final result.
Example:
```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
]> {
let summary = "Broadcast over a device grid.";
let description = [{
- Broadcast the tensor on `root` to all devices in each respective group.
- The operation broadcasts along grid axes `grid_axes`.
- The `root` device specifies the in-group multi-index that is broadcast to
- all other devices in the group.
+ Copies the input tensor on `root` to all devices in each group defined by
+ `grid_axes`. The `root` device is defined by its in-group multi-index.
+ The contents of input tensors on non-root devices are ignored.
Example:
```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
- device (1, 0) -> | | | <- device (1, 1)
+ device (1, 0) -> | * * | * * | <- device (1, 1)
+-------+-------+
```
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
]> {
let summary = "Gather over a device grid.";
let description = [{
- Gathers on device `root` along the `gather_axis` tensor axis.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device for each device group.
- The result tensor on non-root devices is undefined.
- Using it will result in undefined behavior.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and returns the resulting tensor on each
+ `root` device. The result on all other (non-root) devices is undefined.
+ The `root` device is defined by its in-group multi-index.
Example:
```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
]> {
let summary = "Receive over a device grid.";
let description = [{
- Receive from a device within a device group.
+ Receives a tensor from the device `source`, which is defined by its in-group
+ multi-index. The groups are defined by `grid_axes`.
+ The content of the input tensor is ignored.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
]> {
let summary = "Reduce over a device grid.";
let description = [{
- Reduces on device `root` within each device group.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device within its device group.
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ The reduction result will be returned on the `root` device of each group.
+ It is undefined on all other (non-root) devices.
+ The `root` device is defined by its in-group multi-index.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device grid.";
let description = [{
- After the reduction, the result is scattered within each device group.
- The tensor is split along `scatter_axis` and the pieces distributed
- across the device group.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes` using the specified reduction method. The reduction is performed
+ element-wise across the tensor pieces from all devices in the group.
+ After reduction, the reduction result is scattered (split and distributed)
+ across the device group along `scatter_axis`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
- : tensor<3x4xf32> -> tensor<1x4xf64>
+ : tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
Result:
```
+-------+
- | 6 8 | <- devices (0, 0)
+ | 5 6 | <- devices (0, 0)
+-------+
- | 10 12 | <- devices (0, 1)
+ | 7 8 | <- devices (0, 1)
+-------+
- | 22 24 | <- devices (1, 0)
+ | 13 14 | <- devices (1, 0)
+-------+
- | 26 28 | <- devices (1, 1)
+ | 15 16 | <- devices (1, 1)
+-------+
```
}];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
]> {
let summary = "Scatter over a device grid.";
let description = [{
- For each device group split the input tensor on the `root` device along
- axis `scatter_axis` and scatter the parts across the group devices.
+ For each device group defined by `grid_axes`, the input tensor on the `root`
+ device is split along axis `scatter_axis` and distributed across the group.
+ The content of the input on all other (non-root) devices is ignored.
+ The `root` device is defined by its in-group multi-index.
Example:
```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
(0, 1)
↓
+-------+-------+ | scatter tensor
- device (0, 0) -> | | | | axis 0
- | | | ↓
+ device (0, 0) -> | * * | * * | | axis 0
+ | * * | * * | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Send from one device to another within a device group.
+ Sends the input tensor to the device `destination`, which is defined by its
+ in-group multi-index. The groups are defined by `grid_axes`.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
]> {
let summary = "Shift over a device grid.";
let description = [{
- Within each device group shift along grid axis `shift_axis` by an offset
- `offset`.
- The result on devices that do not have a corresponding source is undefined.
- `shift_axis` must be one of `grid_axes`.
- If the `rotate` attribute is present,
- instead of a shift a rotation is done.
+ Within each device group defined by `grid_axes`, shifts input tensors along the
+ device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
+ be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+ That is, the offset wraps around according to the group size along `shift_axis`.
+ Otherwise, the results on devices without a corresponding source are undefined.
Example:
```
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 10491f6..4ecf03c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+/// A thin wrapper around the SpecificationVersion enum to represent
+/// and provide utilities around the TOSA specification version.
+class TosaSpecificationVersion {
+public:
+ TosaSpecificationVersion(uint32_t major, uint32_t minor)
+ : majorVersion(major), minorVersion(minor) {}
+ TosaSpecificationVersion(SpecificationVersion version)
+ : TosaSpecificationVersion(fromVersionEnum(version)) {}
+
+ bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
+ return this->majorVersion == baseVersion.majorVersion &&
+ this->minorVersion >= baseVersion.minorVersion;
+ }
+
+ uint32_t getMajor() const { return majorVersion; }
+ uint32_t getMinor() const { return minorVersion; }
+
+private:
+ uint32_t majorVersion = 0;
+ uint32_t minorVersion = 0;
+
+ static TosaSpecificationVersion
+ fromVersionEnum(SpecificationVersion version) {
+ switch (version) {
+ case SpecificationVersion::V_1_0:
+ return TosaSpecificationVersion(1, 0);
+ case SpecificationVersion::V_1_1_DRAFT:
+ return TosaSpecificationVersion(1, 1);
+ }
+ llvm_unreachable("Unknown TOSA version");
+ }
+};
+
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
+
/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+ const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
- : level(level) {
+ : specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
explicit TargetEnv(TargetEnvAttr targetAttr)
- : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
- targetAttr.getExtensions()) {}
+ : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions()) {}
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
- // TODO implement the following utilities.
- // Version getSpecVersion() const;
+ SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -105,6 +140,7 @@ public:
}
private:
+ SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 1f718ac..c1b5e78 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -2,441 +2,779 @@
// `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git
profileComplianceMap = {
{"tosa.argmax",
- {{{Profile::pro_int}, {{i8T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp32T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.add",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.arithmetic_right_shift",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_and",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_or",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_xor",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.intdiv",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_and",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_left_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_right_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_or",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_xor",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.maximum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.minimum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.mul",
- {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
- {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pow",
- {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.sub",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}},
{"tosa.abs",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_not",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
- {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
- {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clz",
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.logical_not",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.negate",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reciprocal",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.select",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater_equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_all",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_any",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_max",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_min",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_product",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_sum",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
{{{Profile::pro_int},
- {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}},
+ {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
{{{Profile::pro_int},
- {{i8T, i32T, i8T, i8T},
- {i16T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}},
+ {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Profile::pro_int},
- {{boolT, i8T},
- {boolT, i16T},
- {boolT, i32T},
- {i8T, boolT},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, boolT},
- {i16T, i8T},
- {i16T, i32T},
- {i32T, boolT},
- {i32T, i8T},
- {i32T, i16T}}},
- {{Profile::pro_fp},
- {{i8T, fp16T},
- {i8T, fp32T},
- {i16T, fp16T},
- {i16T, fp32T},
- {i32T, fp16T},
- {i32T, fp32T},
- {fp16T, i8T},
- {fp16T, i16T},
- {fp16T, i32T},
- {fp16T, fp32T},
- {fp32T, i8T},
- {fp32T, i16T},
- {fp32T, i32T},
- {fp32T, fp16T}}}}},
+ {{{boolT, i8T}, SpecificationVersion::V_1_0},
+ {{boolT, i16T}, SpecificationVersion::V_1_0},
+ {{boolT, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, boolT}, SpecificationVersion::V_1_0},
+ {{i16T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, boolT}, SpecificationVersion::V_1_0},
+ {{i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+ {{i8T, fp32T}, SpecificationVersion::V_1_0},
+ {{i16T, fp16T}, SpecificationVersion::V_1_0},
+ {{i16T, fp32T}, SpecificationVersion::V_1_0},
+ {{i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{i32T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, i8T}, SpecificationVersion::V_1_0},
+ {{fp16T, i16T}, SpecificationVersion::V_1_0},
+ {{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i8T}, SpecificationVersion::V_1_0},
+ {{fp32T, i16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i8T, i8T, i16T, i16T},
- {i8T, i8T, i32T, i32T},
- {i16T, i16T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i16T, i16T, i32T, i32T},
- {i32T, i32T, i8T, i8T},
- {i32T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT}, {i8T}, {i16T}, {i32T}},
+ {{{boolT}, SpecificationVersion::V_1_0},
+ {{i8T}, SpecificationVersion::V_1_0},
+ {{i16T}, SpecificationVersion::V_1_0},
+ {{i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
extensionComplianceMap = {
{"tosa.argmax",
- {{{Extension::int16}, {{i16T, i32T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
- {{Extension::bf16}, {{bf16T, i32T}}}}},
+ {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
+ {"tosa.fft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
- {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
- {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3, Extension::fp8e5m2},
- {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
- {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}},
+ {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}},
allOf},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rfft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
- {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
- {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.add",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.maximum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.minimum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.mul",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pow",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sub",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Extension::int16},
+ {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.abs",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.negate",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reciprocal",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.select",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater_equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_max",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_min",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_product",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_sum",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i48T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Extension::bf16},
- {{i8T, bf16T},
- {i16T, bf16T},
- {i32T, bf16T},
- {bf16T, i8T},
- {bf16T, i16T},
- {bf16T, i32T},
- {bf16T, fp32T},
- {fp32T, bf16T}}},
+ {{{i8T, bf16T}, SpecificationVersion::V_1_0},
+ {{i16T, bf16T}, SpecificationVersion::V_1_0},
+ {{i32T, bf16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i8T}, SpecificationVersion::V_1_0},
+ {{bf16T, i16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i32T}, SpecificationVersion::V_1_0},
+ {{bf16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
{{Extension::bf16, Extension::fp8e4m3},
- {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}},
+ {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::bf16, Extension::fp8e5m2},
- {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}},
+ {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp16T},
- {fp8e4m3T, fp32T},
- {fp16T, fp8e4m3T},
- {fp32T, fp8e4m3T}}},
+ {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp16T},
- {fp8e5m2T, fp32T},
- {fp16T, fp8e5m2T},
- {fp32T, fp8e5m2T}}}}},
+ {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Extension::int16},
- {{i48T, i48T, i8T, i8T},
- {i48T, i48T, i16T, i16T},
- {i48T, i48T, i32T, i32T}}}}},
+ {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
- {{{Extension::int4}, {{i4T}}},
- {{Extension::int16}, {{i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T}}}}},
+ {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
- {{{Extension::int4}, {{i4T, i4T}}},
- {{Extension::int16}, {{i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
+
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 38cb293..8376a4c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 1.5.
+// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
-def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-
-def Tosa_LevelAttr
- : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
-
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -405,17 +399,40 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
}
//===----------------------------------------------------------------------===//
+// TOSA Levels
+//===----------------------------------------------------------------------===//
+
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
+
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Specification versions
+//===----------------------------------------------------------------------===//
+
+def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
+def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
+
+def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
+ "SpecificationVersion", "TOSA specification version", "specification_version",
+ [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
+
+//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
+ "SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);
- let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
+ "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 8f5c72b..7b946ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};
+using VersionedTypeInfo =
+ std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
+
template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
- SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
+ SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};
@@ -130,9 +133,8 @@ public:
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
- SmallVector<T> findMatchedProfile(Operation *op,
- SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition);
+ OpComplianceInfo<T>
+ findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
@@ -168,8 +170,7 @@ public:
private:
template <typename T>
- FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
- CheckCondition &condition);
+ FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 6ae19d8..14b00b0 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
let options = [
+ Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
+ /*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
+ "The specification version that TOSA operators should conform to.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
+ )}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e79085..6e15b1e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [
}];
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
let assemblyFormat = "attr-dict `:` type($result)";
+ let hasCanonicalizer = 1;
}
def Vector_YieldOp : Vector_Op<"yield", [
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index b80ee2c..e9425e8 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target";
}
-def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
+def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<
+ "block",
+ "Create a nesting level with a label at its exit."> {
+ let description = [{
+ Defines a Wasm block, creating a new nested scope.
+ A block contains a body region and an optional list of input values.
+ Control can enter the block and later branch out to the block target.
+ Example:
+
+ ```mlir
+
+ wasmssa.block {
+
+ // instructions
+
+ } > ^successor
+ }];
+}
+
+def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<
+ "loop",
+ "Create a nesting level that define its entry as jump target."> {
+ let description = [{
+ Represents a Wasm loop construct. This defines a nesting level with
+ a label at the entry of the region.
-def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
+ Example:
+
+ ```mlir
+
+ wasmssa.loop {
+
+ } > ^successor
+ }];
+}
def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
@@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
::mlir::Block* getTarget();
}];
let description = [{
- Marks a return from the current block.
+ Escape from the current nesting level and return the control flow to its successor.
+ Optionally, mark the arguments that should be transferred to the successor block.
- Example:
+ This shouldn't be confused with branch operations that target the label defined
+ by the nesting level operation.
+
+ For instance, a `wasmssa.block_return` in a loop will give back control to the
+ successor of the loop, where a `branch` targeting the loop will flow back to the entry block of the loop.
+
+ Example:
```mlir
wasmssa.block_return
@@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
- Arguments of the entry block of type `!wasm<local T>`, with T the corresponding type
in the function type.
+ By default, `wasmssa.func` has nested visibility. Functions exported by the module
+ are marked with the exported attribute. This gives them public visibility.
+
Example:
```mlir
- // A simple function with no arguments that returns a float32
+ // Internal function with no arguments that returns a float32
wasmssa.func @my_f32_func() -> f32
+ // Exported function with no arguments that returns a float32
+ wasmssa.func exported @my_f32_func() -> f32
+
// A function that takes a local ref argument
wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32
```
@@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
WasmSSA_FuncTypeAttr: $functionType,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
- DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility);
+ UnitAttr: $exported);
let regions = (region AnyRegion: $body);
let extraClassDeclaration = [{
@@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
@@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
StrAttr: $importName,
WasmSSA_FuncTypeAttr: $type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
- OptionalAttr<DictArrayAttr>:$res_attrs,
- OptionalAttr<StrAttr>:$sym_visibility);
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
@@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
::llvm::ArrayRef<Type> getResultTypes() {
return getType().getResults();
}
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
OpBuilder<(ins "StringRef":$symbol,
@@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_ValTypeAttr: $type,
UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $exported);
let description = [{
WebAssembly global variable.
Body contains the initialization instructions for the variable value.
The body must contain only instructions considered `const` in a webassembly context,
such as `wasmssa.const` or `global.get`.
+ By default, `wasmssa.global` has nested visibility. Globals exported by the module
+ are marked with the exported attribute. This gives them public visibility.
+
Example:
```mlir
- // Define a global_var, a mutable i32 global variable equal to 10.
- wasmssa.global @global_var i32 mutable nested : {
+ // Define module_global_var, an internal mutable i32 global variable equal to 10.
+ wasmssa.global @module_global_var i32 mutable : {
%[[VAL_0:.*]] = wasmssa.const 10 : i32
wasmssa.return %[[VAL_0]] : i32
}
+
+ // Define global_var, an exported constant i32 global variable equal to 42.
+ wasmssa.global exported @global_var i32 : {
+ %[[VAL_0:.*]] = wasmssa.const 42 : i32
+ wasmssa.return %[[VAL_0]] : i32
+ }
```
}];
let regions = (region AnyRegion: $initializer);
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "Type": $type,
- "bool": $isMutable)>
- ];
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
let hasCustomAssemblyFormat = 1;
}
@@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
StrAttr: $moduleName,
StrAttr: $importName,
WasmSSA_ValTypeAttr: $type,
- UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $isMutable);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "StringRef":$moduleName,
- "StringRef":$importName,
- "Type": $type,
- "bool": $isMutable)>
- ];
let hasCustomAssemblyFormat = 1;
}
@@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
Define a memory to be used by the program.
Multiple memories can be defined in the same module.
+ By default, `wasmssa.memory` has nested visibility. Memories exported by
+ the module are marked with the exported attribute. This gives them public
+ visibility.
+
Example:
```mlir
- // Define the `mem_0` memory with defined bounds of 0 -> 65536
+ // Define the `mem_0` (internal) memory with defined size bounds of [0:65536]
wasmssa.memory @mem_0 !wasmssa<limit[0:65536]>
+
+ // Define the `mem_1` exported memory with minimal size of 512
+ wasmssa.memory exported @mem_1 !wasmssa<limit[512:]>
```
}];
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [
- OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::LimitType":$limit)>
- ];
+ UnitAttr: $exported);
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict";
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict";
}
def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> {
@@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]>
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_LimitTypeAttr: $limits);
let extraClassDeclaration = [{
- bool isDeclaration() const { return true; }
+ bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::LimitType":$limits)>];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
}
@@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
let summary= "WebAssembly table value";
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::TableType":$type)>];
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict";
+ UnitAttr: $exported);
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $type attr-dict";
}
def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> {
@@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_TableTypeAttr: $type);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::TableType":$type)>];
}
def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d..19a5231 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().contains(name);
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
return getAttrs().getAs<ArrayAttr>("stride");
}
+ ArrayAttr getBlockAttr() {
+ return getAttrs().getAs<ArrayAttr>("block");
+ }
+
}];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061..426377f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
- AllElementTypesMatch<["mem_desc", "res"]>,
- AllRanksMatch<["mem_desc", "res"]>]> {
+ AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
- let results = (outs XeGPU_ValueType:$res);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
Arguments:
- `mem_desc`: the memory descriptor identifying the SLM region.
- `offsets`: the coordinates within the matrix to read from.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block load. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
ArrayRef<int64_t> getDataShape() {
- return getRes().getType().getShape();
+ auto resTy = getRes().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(resTy)) // vector result: report its shape
+ return vecTy.getShape();
+ return {}; // non-vector result: empty shape
}
}];
@@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- AllElementTypesMatch<["mem_desc", "data"]>,
- AllRanksMatch<["mem_desc", "data"]>]> {
+ AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
- XeGPU_ValueType:$data,
+ AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- `mem_desc`: the memory descriptor specifying the SLM region.
- `offsets`: the coordinates within the matrix where the data will be written.
- `data`: the values to be stored in the matrix.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block store. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}
ArrayRef<int64_t> getDataShape() {
- return getData().getType().getShape();
+ auto dataTy = getData().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(dataTy)) // vector data: report its shape
+ return vecTy.getShape();
+ return {}; // non-vector data: empty shape
}
}];
@@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let hasVerifier = 1;
}
-def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
- [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
- let description = [{
- Creates a subview of a memory descriptor. The resulting memory descriptor can have
- a lower rank than the source; in this case, the result dimensions correspond to the
- higher-order dimensions of the source memory descriptor.
-
- Arguments:
- - `src` : a memory descriptor.
- - `offsets` : the coordinates within the matrix the subview will be created from.
-
- Results:
- - `res` : a memory descriptor with smaller size.
-
- }];
- let arguments = (ins XeGPU_MemDesc:$src,
- Variadic<Index>:$offsets,
- DenseI64ArrayAttr:$const_offsets);
- let results = (outs XeGPU_MemDesc:$res);
- let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
- attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
- let builders = [
- OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
- ];
-
- let extraClassDeclaration = [{
- mlir::Value getViewSource() { return getSrc(); }
-
- SmallVector<OpFoldResult> getMixedOffsets() {
- return getMixedValues(getConstOffsets(), getOffsets(), getContext());
- }
- }];
-
- let hasVerifier = 1;
-}
-
-
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2..b1196fb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
- return layout.getStrides();
+ return layout.getStrideAttr();
}
-
// derive and return default strides
SmallVector<int64_t> defaultStrides;
llvm::append_range(defaultStrides, getShape().drop_front());
@@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}
+
+ ArrayAttr getBlockAttr() { // Block attribute of the memory layout, or an empty array attribute.
+ auto layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ return layout.getBlockAttr();
+ }
+ Builder builder(getContext()); // no layout, or the layout has no "block" entry
+ return builder.getI64ArrayAttr({});
+ }
+
+ /// Heuristic to determine if the MemDesc uses column-major layout,
+ /// based on the rank and the value of the first stride dimension:
+ /// a rank-2 MemDesc whose first stride is 1 is treated as column-major.
+ bool isColMajor() {
+ // Check the rank first so the stride array is only indexed for 2-D descriptors.
+ if (getRank() != 2)
+ return false;
+ // Stride entries are expected to be IntegerAttrs. `cast` asserts on a mismatch
+ // instead of dereferencing the null result of an unchecked `dyn_cast`.
+ return cast<IntegerAttr>(getStrideAttr()[0]).getInt() == 1;
+ }
+
+ // Get the blocking shape of this MemDescType, which is taken from the
+ // block attribute returned by getBlockAttr(). When that attribute is
+ // empty, the blocking shape defaults to the full shape of the MemDescType.
+ SmallVector<int64_t> getBlockShape() {
+ SmallVector<int64_t> size(getShape());
+ ArrayAttr blockAttr = getBlockAttr();
+ if (!blockAttr.empty()) {
+ size.clear();
+ for (auto attr : blockAttr.getValue()) {
+ size.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ }
+ return size;
+ }
+
+ // Get strides as vector of integer.
+ // If it contains block attribute, the strides are blocked strides.
+ //
+ // The blocking is applied to the base matrix shape derived from the
+ // memory descriptor's stride information. If the matrix described by
+ // the memory descriptor is not contiguous, it is assumed that the base
+ // matrix is contiguous and follows the same memory layout.
+ //
+ // It first computes the original matrix shape using the stride info,
+ // then computes the number of blocks in each dimension of original shape,
+ // then compute the outer block shape and stride,
+ // then combines the inner and outer block shape and stride
+ // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @stride=[1, 32]>`
+ // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+ // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride=[32, 1]
+ // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+ SmallVector<int64_t> getStrideShape();
+
+ /// Generates instructions to compute the linearized offset.
+ // If the memory descriptor is blocked, it returns the linearized offset based on the blocked layout.
+ // The strides of the memory descriptor are always considered, whether it is blocked or not.
+ Value getLinearOffsets(OpBuilder &builder,
+ Location loc, ArrayRef<OpFoldResult> offsets);
+
+
}];
let hasCustomAssemblyFormat = true;
diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h
index 20e84ec..9877926 100644
--- a/mlir/include/mlir/IR/Remarks.h
+++ b/mlir/include/mlir/IR/Remarks.h
@@ -18,7 +18,6 @@
#include "llvm/Remarks/Remark.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Regex.h"
-#include <optional>
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
@@ -144,7 +143,7 @@ public:
llvm::StringRef getCategoryName() const { return categoryName; }
- llvm::StringRef getFullCategoryName() const {
+ llvm::StringRef getCombinedCategoryName() const {
if (categoryName.empty() && subCategoryName.empty())
return {};
if (subCategoryName.empty())
@@ -318,7 +317,7 @@ private:
};
//===----------------------------------------------------------------------===//
-// MLIR Remark Streamer
+// Pluggable Remark Utilities
//===----------------------------------------------------------------------===//
/// Base class for MLIR remark streamers that is used to stream
@@ -338,6 +337,26 @@ public:
virtual void finalize() {} // optional
};
+using ReportFn = llvm::unique_function<void(const Remark &)>;
+
+/// Base class for MLIR remark emitting policies, which are used to emit
+/// optimization remarks through the callback installed with `initialize`.
+/// Derived classes must implement the `reportRemark` and `finalize` methods
+/// to provide the actual emitting behavior.
+class RemarkEmittingPolicyBase {
+protected:
+ ReportFn reportImpl; // callback that derived policies invoke to report a remark
+
+public:
+ RemarkEmittingPolicyBase() = default;
+ virtual ~RemarkEmittingPolicyBase() = default;
+
+ void initialize(ReportFn fn) { reportImpl = std::move(fn); } // installs the report callback
+
+ virtual void reportRemark(const Remark &remark) = 0;
+ virtual void finalize() = 0;
+};
+
//===----------------------------------------------------------------------===//
// Remark Engine (MLIR Context will own this class)
//===----------------------------------------------------------------------===//
@@ -355,6 +374,8 @@ private:
std::optional<llvm::Regex> failedFilter;
/// The MLIR remark streamer that will be used to emit the remarks.
std::unique_ptr<MLIRRemarkStreamerBase> remarkStreamer;
+ /// The MLIR remark policy that will be used to emit the remarks.
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy;
/// When is enabled, engine also prints remarks as mlir::emitRemarks.
bool printAsEmitRemarks = false;
@@ -392,6 +413,8 @@ private:
InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts,
bool (RemarkEngine::*isEnabled)(StringRef)
const);
+ /// Report a remark.
+ void reportImpl(const Remark &remark);
public:
/// Default constructor is deleted, use the other constructor.
@@ -407,8 +430,15 @@ public:
~RemarkEngine();
/// Setup the remark engine with the given output path and format.
- LogicalResult initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
- std::string *errMsg);
+ LogicalResult
+ initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ std::string *errMsg);
+
+ /// Get the remark emitting policy as a non-owning pointer (the engine keeps ownership).
+ RemarkEmittingPolicyBase *getRemarkEmittingPolicy() const {
+ return remarkEmittingPolicy.get();
+ }
/// Report a remark.
void report(const Remark &&remark);
@@ -446,6 +476,46 @@ inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) {
namespace mlir::remark {
+//===----------------------------------------------------------------------===//
+// Remark Emitting Policies
+//===----------------------------------------------------------------------===//
+
+/// Policy that emits every remark immediately by forwarding it to `reportImpl`.
+class RemarkEmittingPolicyAll : public detail::RemarkEmittingPolicyBase {
+public:
+ RemarkEmittingPolicyAll();
+
+ void reportRemark(const detail::Remark &remark) override {
+ assert(reportImpl && "reportImpl is not set");
+ reportImpl(remark);
+ }
+ void finalize() override {} // this policy stores nothing, so there is nothing to report here
+};
+
+/// Policy that emits final remarks: remarks are stored and only reported by `finalize`.
+class RemarkEmittingPolicyFinal : public detail::RemarkEmittingPolicyBase {
+private:
+ /// Postponed remarks. The user can intercept them for custom processing via a
+ /// registered callback, otherwise they will be reported on engine destruction.
+ llvm::DenseSet<detail::Remark> postponedRemarks;
+
+public:
+ RemarkEmittingPolicyFinal();
+
+ void reportRemark(const detail::Remark &remark) override {
+ postponedRemarks.erase(remark); // drop any stored remark that compares equal ...
+ postponedRemarks.insert(remark); // ... so the most recently reported one is kept
+ }
+
+ void finalize() override {
+ assert(reportImpl && "reportImpl is not set");
+ for (auto &remark : postponedRemarks) {
+ if (reportImpl) // NOTE(review): presumably guards builds where the assert is compiled out
+ reportImpl(remark);
+ }
+ }
+};
+
/// Create a Reason with llvm::formatv formatting.
template <class... Ts>
inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) {
@@ -505,16 +575,72 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) {
/// Setup remarks for the context. This function will enable the remark engine
/// and set the streamer to be used for optimization remarks. The remark
-/// categories are used to filter the remarks that will be emitted by the remark
-/// engine. If a category is not specified, it will not be emitted. If
+/// categories are used to filter the remarks that will be emitted by the
+/// remark engine. If a category is not specified, it will not be emitted. If
/// `printAsEmitRemarks` is true, the remarks will be printed as
/// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and
/// will be used to stream the remarks.
LogicalResult enableOptimizationRemarks(
MLIRContext &ctx,
std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<remark::detail::RemarkEmittingPolicyBase>
+ remarkEmittingPolicy,
const remark::RemarkCategories &cats, bool printAsEmitRemarks = false);
} // namespace mlir::remark
+// DenseMapInfo specialization for Remark
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::remark::detail::Remark> {
+ static constexpr StringRef kEmptyKey = "<EMPTY_KEY>";
+ static constexpr StringRef kTombstoneKey = "<TOMBSTONE_KEY>";
+
+ /// Helper to provide a static dummy context for sentinel keys.
+ static mlir::MLIRContext *getStaticDummyContext() {
+ static mlir::MLIRContext dummyContext;
+ return &dummyContext;
+ }
+
+ /// Create the empty-key sentinel: an unknown-kind remark named `kEmptyKey`.
+ static inline mlir::remark::detail::Remark getEmptyKey() {
+ return mlir::remark::detail::Remark(
+ mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
+ mlir::UnknownLoc::get(getStaticDummyContext()),
+ mlir::remark::RemarkOpts::name(kEmptyKey));
+ }
+
+ /// Create the tombstone sentinel: an unknown-kind remark named `kTombstoneKey`.
+ static inline mlir::remark::detail::Remark getTombstoneKey() {
+ return mlir::remark::detail::Remark(
+ mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
+ mlir::UnknownLoc::get(getStaticDummyContext()),
+ mlir::remark::RemarkOpts::name(kTombstoneKey));
+ }
+
+ /// Hash the location, remark name, and combined category name of the remark.
+ static unsigned getHashValue(const mlir::remark::detail::Remark &remark) {
+ return llvm::hash_combine(
+ remark.getLocation().getAsOpaquePointer(),
+ llvm::hash_value(remark.getRemarkName()),
+ llvm::hash_value(remark.getCombinedCategoryName()));
+ }
+
+ static bool isEqual(const mlir::remark::detail::Remark &lhs,
+ const mlir::remark::detail::Remark &rhs) {
+ // If either side is a sentinel (empty/tombstone) key, compare by name only.
+ if (lhs.getRemarkName() == kEmptyKey ||
+ lhs.getRemarkName() == kTombstoneKey ||
+ rhs.getRemarkName() == kEmptyKey ||
+ rhs.getRemarkName() == kTombstoneKey) {
+ return lhs.getRemarkName() == rhs.getRemarkName();
+ }
+
+ // Regular remarks are equal when location, name, and combined category match.
+ return lhs.getLocation() == rhs.getLocation() &&
+ lhs.getRemarkName() == rhs.getRemarkName() &&
+ lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName();
+ }
+};
+} // namespace llvm
#endif // MLIR_IR_REMARKS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index a5feb59..72ed046 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
+add_mlir_interface(InferStridedMetadataInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(MemOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e8..a6de3d1 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -117,7 +117,8 @@ public:
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
/// Create an integer value range lattice value.
- IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+ explicit IntegerValueRange(
+ std::optional<ConstantIntRanges> value = std::nullopt)
: value(std::move(value)) {}
/// Whether the range is uninitialized. This happens when the state hasn't
@@ -167,6 +168,15 @@ using SetIntRangeFn =
using SetIntLatticeFn =
llvm::function_ref<void(Value, const IntegerValueRange &)>;
+/// Helper callback type to get the integer range of a value.
+using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
+
+/// Helper function to collect the integer range values of an array of op fold
+/// results.
+SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
+ GetIntRangeFn getIntRange,
+ int32_t indexBitwidth);
+
class InferIntRangeInterface;
namespace intrange::detail {
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
new file mode 100644
index 0000000..0c572e0
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
@@ -0,0 +1,145 @@
+//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the strided metadata inference interface
+// defined in `InferStridedMetadataInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+/// A class that represents the strided metadata range information, including
+/// offsets, sizes, and strides as integer ranges.
+class StridedMetadataRange {
+public:
+ /// Default constructor creates uninitialized ranges.
+ StridedMetadataRange() = default;
+
+ /// Returns a ranked strided metadata range.
+ static StridedMetadataRange
+ getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
+ SmallVectorImpl<ConstantIntRanges> &&sizes,
+ SmallVectorImpl<ConstantIntRanges> &&strides) {
+ return StridedMetadataRange(std::move(offsets), std::move(sizes),
+ std::move(strides));
+ }
+
+ /// Returns a strided metadata range with maximum ranges.
+ static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+ int32_t offsetsRank,
+ int32_t sizeRank,
+ int32_t stridedRank) {
+ return StridedMetadataRange(
+ SmallVector<ConstantIntRanges>(
+ offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
+ SmallVector<ConstantIntRanges>(
+ sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
+ SmallVector<ConstantIntRanges>(
+ stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
+ }
+
+ static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+ int32_t rank) {
+ return getMaxRanges(indexBitwidth, 1, rank, rank); // one offset range; `rank` size and stride ranges
+ }
+
+ /// Returns whether the metadata is uninitialized.
+ bool isUninitialized() const { return !offsets.has_value(); }
+
+ /// Get the offsets ranges; empty when the metadata is uninitialized.
+ ArrayRef<ConstantIntRanges> getOffsets() const {
+ return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
+ }
+ MutableArrayRef<ConstantIntRanges> getOffsets() {
+ return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
+ }
+
+ /// Get the sizes ranges.
+ ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
+ MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
+
+ /// Get the strides ranges.
+ ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
+ MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
+
+ /// Compare two strided metadata ranges.
+ bool operator==(const StridedMetadataRange &other) const {
+ return offsets == other.offsets && sizes == other.sizes &&
+ strides == other.strides;
+ }
+
+ /// Print the strided metadata range.
+ void print(raw_ostream &os) const;
+
+ /// Join two strided metadata ranges, by taking the element-wise union of the
+ /// metadata.
+ static StridedMetadataRange join(const StridedMetadataRange &lhs,
+ const StridedMetadataRange &rhs) {
+ if (lhs.isUninitialized())
+ return rhs;
+ if (rhs.isUninitialized())
+ return lhs;
+
+ // Helper function to compute the range union of constant ranges.
+ auto rangeUnion =
+ +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
+ -> ConstantIntRanges {
+ return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
+ };
+
+ // Get the elementwise range union. Note, that `zip_equal` will assert if
+ // sizes are not equal.
+ SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
+ llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
+ SmallVector<ConstantIntRanges> sizes =
+ llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
+ SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
+ llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
+
+ // Return the joined metadata.
+ return StridedMetadataRange(std::move(offsets), std::move(sizes),
+ std::move(strides));
+ }
+
+private:
+ /// Create a strided metadata range with the given offsets, sizes, and strides.
+ StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
+ SmallVectorImpl<ConstantIntRanges> &&sizes,
+ SmallVectorImpl<ConstantIntRanges> &&strides)
+ : offsets(std::move(offsets)), sizes(std::move(sizes)),
+ strides(std::move(strides)) {}
+
+ /// The offsets ranges; `std::nullopt` marks the metadata as uninitialized.
+ std::optional<SmallVector<ConstantIntRanges>> offsets;
+
+ /// The sizes ranges.
+ SmallVector<ConstantIntRanges> sizes;
+
+ /// The strides ranges.
+ SmallVector<ConstantIntRanges> strides;
+};
+
+/// Print the strided metadata range to `os` and return `os`.
+inline raw_ostream &operator<<(raw_ostream &os,
+ const StridedMetadataRange &range) {
+ range.print(os);
+ return os;
+}
+
+/// Callback function type for setting the strided metadata of a value.
+using SetStridedMetadataRangeFn =
+ function_ref<void(Value, const StridedMetadataRange &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
new file mode 100644
index 0000000..ee5b094
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
@@ -0,0 +1,45 @@
+//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for strided metadata range analysis
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferStridedMetadataOpInterface : // Op interface for strided metadata range inference.
+ OpInterface<"InferStridedMetadataOpInterface"> {
+ let description = [{
+ Allows operations to participate in strided metadata analysis by providing
+ methods that allow them to specify bounds on offsets, sizes, and strides
+ of their result(s) given bounds on their input(s) if known.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Infer the strided metadata bounds on the results of this op given
+ the bounds on its operands.
+ For each result value or block argument of interest, the method should
+ call `setMetadata` with that `Value` as an argument.
+ The `operands` parameter contains the strided metadata ranges for all the
+ operands of the operation in order.
+ The `getIntRange` callback is provided for obtaining the int-range
+ analysis result for a given value.
+ }],
+ "void", "inferStridedMetadataRanges",
+ (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
+ "::mlir::GetIntRangeFn":$getIntRange,
+ "::mlir::SetStridedMetadataRangeFn":$setMetadata,
+ "int32_t":$indexBitwidth)>
+ ];
+}
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
diff --git a/mlir/include/mlir/Remark/RemarkStreamer.h b/mlir/include/mlir/Remark/RemarkStreamer.h
index 170d6b4..19a70fa 100644
--- a/mlir/include/mlir/Remark/RemarkStreamer.h
+++ b/mlir/include/mlir/Remark/RemarkStreamer.h
@@ -45,6 +45,7 @@ namespace mlir::remark {
/// mlir::emitRemarks.
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
const RemarkCategories &cat, bool printAsEmitRemarks = false);
} // namespace mlir::remark
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index 252da21..997aef2 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -88,7 +88,7 @@ public:
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
- void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs);
+ void emitOpConstraints();
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index 21adde8..cd9ef5b 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -19,6 +19,14 @@ namespace mlir {
struct WasmBinaryEncoding {
/// Byte encodings for Wasm instructions.
struct OpCode {
+ // Control instructions.
+ static constexpr std::byte block{0x02};
+ static constexpr std::byte loop{0x03};
+ static constexpr std::byte ifOpCode{0x04};
+ static constexpr std::byte elseOpCode{0x05};
+ static constexpr std::byte branchIf{0x0D};
+ static constexpr std::byte call{0x10};
+
// Locals, globals, constants.
static constexpr std::byte localGet{0x20};
static constexpr std::byte localSet{0x21};
@@ -29,6 +37,42 @@ struct WasmBinaryEncoding {
static constexpr std::byte constFP32{0x43};
static constexpr std::byte constFP64{0x44};
+ // Comparisons.
+ static constexpr std::byte eqzI32{0x45};
+ static constexpr std::byte eqI32{0x46};
+ static constexpr std::byte neI32{0x47};
+ static constexpr std::byte ltSI32{0x48};
+ static constexpr std::byte ltUI32{0x49};
+ static constexpr std::byte gtSI32{0x4A};
+ static constexpr std::byte gtUI32{0x4B};
+ static constexpr std::byte leSI32{0x4C};
+ static constexpr std::byte leUI32{0x4D};
+ static constexpr std::byte geSI32{0x4E};
+ static constexpr std::byte geUI32{0x4F};
+ static constexpr std::byte eqzI64{0x50};
+ static constexpr std::byte eqI64{0x51};
+ static constexpr std::byte neI64{0x52};
+ static constexpr std::byte ltSI64{0x53};
+ static constexpr std::byte ltUI64{0x54};
+ static constexpr std::byte gtSI64{0x55};
+ static constexpr std::byte gtUI64{0x56};
+ static constexpr std::byte leSI64{0x57};
+ static constexpr std::byte leUI64{0x58};
+ static constexpr std::byte geSI64{0x59};
+ static constexpr std::byte geUI64{0x5A};
+ static constexpr std::byte eqF32{0x5B};
+ static constexpr std::byte neF32{0x5C};
+ static constexpr std::byte ltF32{0x5D};
+ static constexpr std::byte gtF32{0x5E};
+ static constexpr std::byte leF32{0x5F};
+ static constexpr std::byte geF32{0x60};
+ static constexpr std::byte eqF64{0x61};
+ static constexpr std::byte neF64{0x62};
+ static constexpr std::byte ltF64{0x63};
+ static constexpr std::byte gtF64{0x64};
+ static constexpr std::byte leF64{0x65};
+ static constexpr std::byte geF64{0x66};
+
// Numeric operations.
static constexpr std::byte clzI32{0x67};
static constexpr std::byte ctzI32{0x68};
@@ -93,6 +137,33 @@ struct WasmBinaryEncoding {
static constexpr std::byte maxF64{0xA5};
static constexpr std::byte copysignF64{0xA6};
static constexpr std::byte wrap{0xA7};
+
+ // Conversion operations
+ static constexpr std::byte extendS{0xAC};
+ static constexpr std::byte extendU{0xAD};
+ static constexpr std::byte convertSI32F32{0xB2};
+ static constexpr std::byte convertUI32F32{0xB3};
+ static constexpr std::byte convertSI64F32{0xB4};
+ static constexpr std::byte convertUI64F32{0xB5};
+
+ static constexpr std::byte demoteF64ToF32{0xB6};
+
+ static constexpr std::byte convertSI32F64{0xB7};
+ static constexpr std::byte convertUI32F64{0xB8};
+ static constexpr std::byte convertSI64F64{0xB9};
+ static constexpr std::byte convertUI64F64{0xBA};
+
+ static constexpr std::byte promoteF32ToF64{0xBB};
+ static constexpr std::byte reinterpretF32AsI32{0xBC};
+ static constexpr std::byte reinterpretF64AsI64{0xBD};
+ static constexpr std::byte reinterpretI32AsF32{0xBE};
+ static constexpr std::byte reinterpretI64AsF64{0xBF};
+
+ static constexpr std::byte extendI328S{0xC0};
+ static constexpr std::byte extendI3216S{0xC1};
+ static constexpr std::byte extendI648S{0xC2};
+ static constexpr std::byte extendI6416S{0xC3};
+ static constexpr std::byte extendI6432S{0xC4};
};
/// Byte encodings of types in Wasm binaries
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 0fbe15f..b739438 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -44,6 +44,11 @@ enum class RemarkFormat {
REMARK_FORMAT_BITSTREAM,
};
+enum class RemarkPolicy { // Remark policy option; read back through getRemarkPolicy().
+ REMARK_POLICY_ALL, // NOTE(review): presumably selects RemarkEmittingPolicyAll -- confirm at the use site
+ REMARK_POLICY_FINAL, // NOTE(review): presumably selects RemarkEmittingPolicyFinal -- confirm at the use site
+};
+
/// Configuration options for the mlir-opt tool.
/// This is intended to help building tools like mlir-opt by collecting the
/// supported options.
@@ -242,6 +247,8 @@ public:
/// Set the reproducer output filename
RemarkFormat getRemarkFormat() const { return remarkFormatFlag; }
+ /// Get the remark policy to use.
+ RemarkPolicy getRemarkPolicy() const { return remarkPolicyFlag; }
/// Set the remark format to use.
std::string getRemarksAllFilter() const { return remarksAllFilterFlag; }
/// Set the remark output file.
@@ -265,6 +272,8 @@ protected:
/// Remark format
RemarkFormat remarkFormatFlag = RemarkFormat::REMARK_FORMAT_STDOUT;
+ /// Remark policy
+ RemarkPolicy remarkPolicyFlag = RemarkPolicy::REMARK_POLICY_ALL;
/// Remark file to output to
std::string remarksOutputFileFlag = "";
/// Remark filters
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34..db10ebc 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
+ DataFlow/StridedMetadataRangeAnalysis.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
@@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRPresburger
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000..01c9daf
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for strided metadata range
+// inference, which computes integer ranges for the offsets, sizes, and strides
+// of memref-like values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much information as possible.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+ // TODO: generalize this method with a type interface.
+ auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+ // If not a memref or it's un-ranked, don't infer any metadata.
+ if (!mTy || !mTy.hasRank())
+ return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+ // Get the top state.
+ auto metadata =
+ StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+ // Compute the offset and strides.
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+ return metadata;
+
+ // Refine the metadata if we know it from the type.
+ if (!ShapedType::isDynamic(offset)) {
+ metadata.getOffsets()[0] =
+ ConstantIntRanges::constant(APInt(indexBitwidth, offset));
+ }
+ for (auto &&[size, range] :
+ llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+ if (ShapedType::isDynamic(size))
+ continue;
+ range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
+ }
+ for (auto &&[stride, range] :
+ llvm::zip_equal(strides, metadata.getStrides())) {
+ if (ShapedType::isDynamic(stride))
+ continue;
+ range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
+ }
+
+ return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+ DataFlowSolver &solver, int32_t indexBitwidth)
+ : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
+ assert(indexBitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+ StridedMetadataRangeLattice *lattice) {
+ propagateIfChanged(lattice, lattice->join(getEntryStateImpl(
+ lattice->getAnchor(), indexBitwidth)));
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+ Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+ ArrayRef<StridedMetadataRangeLattice *> results) {
+ auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
+
+ // Bail if we cannot reason about the op.
+ if (!inferrable) {
+ setAllToEntryStates(results);
+ return success();
+ }
+
+ LDBG() << "Inferring metadata for: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+ // Helper function to retrieve int range values.
+ auto getIntRange = [&](Value value) -> IntegerValueRange {
+ auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
+ getProgramPointAfter(op), value);
+ return lattice ? lattice->getValue() : IntegerValueRange();
+ };
+
+ // Convert the arguments lattices to a vector.
+ SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+ operands, [](const StridedMetadataRangeLattice *lattice) {
+ return lattice->getValue();
+ });
+
+ // Callback to set metadata on a result.
+ auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
+ auto result = cast<OpResult>(v);
+ assert(llvm::is_contained(op->getResults(), result));
+ LDBG() << "- Inferred metadata: " << md;
+ StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
+ ChangeResult changed = lattice->join(md);
+ LDBG() << "- Joined metadata: " << lattice->getValue();
+ propagateIfChanged(lattice, changed);
+ };
+
+ // Infer the metadata.
+ inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
+ indexBitwidth);
+ return success();
+}
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 30ce1fb..6588b53 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther(
/// Checks if the SSA values associated with `cst`'s variables in range
/// [start, end) are unique.
-static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
- const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
+[[maybe_unused]] static bool
+areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start,
+ unsigned end) {
assert(start <= cst.getNumDimAndSymbolVars() &&
"Start position out of bounds");
@@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
}
/// Checks if the SSA values associated with `cst`'s variables are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] static bool
areVarsUnique(const FlatLinearValueConstraints &cst) {
return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
}
/// Checks if the SSA values associated with `cst`'s variables of kind `kind`
/// are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] static bool
areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
if (kind == VarKind::SetDim)
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index a1cbe29..547a4c2 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -34,7 +34,7 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max();
// Return a + scale*b;
-LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]]
static SmallVector<DynamicAPInt, 8>
scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale,
ArrayRef<DynamicAPInt> b) {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7b17106..06d0256 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2730,6 +2730,17 @@ public:
operation->get(), toMlirStringRef(name)));
}
+ static void
+ forEachAttr(MlirOperation op,
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+ intptr_t n = mlirOperationGetNumAttributes(op);
+ for (intptr_t i = 0; i < n; ++i) {
+ MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+ MlirStringRef name = mlirIdentifierStr(na.name);
+ fn(name, na.attribute);
+ }
+ }
+
static void bind(nb::module_ &m) {
nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
.def("__contains__", &PyOpAttributeMap::dunderContains)
@@ -2737,7 +2748,50 @@ public:
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
.def("__setitem__", &PyOpAttributeMap::dunderSetItem)
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
+ .def("__delitem__", &PyOpAttributeMap::dunderDelItem)
+ .def("__iter__",
+ [](PyOpAttributeMap &self) {
+ nb::list keys;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ keys.append(nb::str(name.data, name.length));
+ });
+ return nb::iter(keys);
+ })
+ .def("keys",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ out.append(nb::str(name.data, name.length));
+ });
+ return out;
+ })
+ .def("values",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef, MlirAttribute attr) {
+ out.append(PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast());
+ });
+ return out;
+ })
+ .def("items", [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute attr) {
+ out.append(nb::make_tuple(
+ nb::str(name.data, name.length),
+ PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast()));
+ });
+ return out;
+ });
}
private:
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5ddb3fb..0f0ed22 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -205,7 +205,7 @@ public:
nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
- MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ MlirRewritePattern pattern = mlirOpRewritePatternCreate(
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 46c329d..41ceb15 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -341,7 +341,7 @@ private:
} // namespace mlir
-MlirRewritePattern mlirOpRewritePattenCreate(
+MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames) {
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f8..bebf1b8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index b215211..c03f3a5 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
- populateMathToROCDLConversionPatterns(converter, patterns);
+ populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
index 2771955a..8cc3fde 100644
--- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
Core
LINK_LIBS PUBLIC
+ MLIRAMDGPUUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3..a2dfc12 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,8 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -19,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/DebugLog.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
@@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
+struct ClampFOpConversion final
+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only f16 and f32 types are supported by fmed3
+ Type opTy = op.getType();
+ Type resultType = getTypeConverter()->convertType(opTy);
+
+ if (auto vectorType = dyn_cast<VectorType>(opTy))
+ opTy = vectorType.getElementType();
+
+ if (!isa<Float16Type, Float32Type>(opTy))
+ return rewriter.notifyMatchFailure(
+ op, "fmed3 only supports f16 and f32 types");
+
+ // Handle multi-dimensional vectors (converted to LLVM arrays)
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ typename math::ClampFOp::Adaptor adaptor(operands);
+ return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getValue(), adaptor.getMin(),
+ adaptor.getMax());
+ },
+ rewriter);
+
+ // Handle 1D vectors and scalars directly
+ rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
+ op.getMin(), op.getMax());
+ return success();
+ }
+};
+
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
+
+ if (chipset.has_value() && chipset->majorVersion >= 9) {
+ patterns.add<ClampFOpConversion>(converter);
+ } else {
+ LDBG() << "Chipset dependent patterns were not added";
+ }
}
-namespace {
-struct ConvertMathToROCDLPass
- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
- ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+ : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ using impl::ConvertMathToROCDLBase<
+ ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
void runOnOperation() override;
};
-} // namespace
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
@@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+
+ FailureOr<amdgpu::Chipset> maybeChipset;
+ if (!chipset.empty()) {
+ maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset))
+ return signalPassFailure();
+ }
+ populateMathToROCDLConversionPatterns(
+ converter, patterns,
+ succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
+
ConversionTarget target(getContext());
- target.addLegalDialect<BuiltinDialect, func::FuncDialect,
- vector::VectorDialect, LLVM::LLVMDialect>();
+ target
+ .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+ LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index f0d8b78..610ce1f 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (auto vectorType = dyn_cast<VectorType>(operandType))
nanAttr = DenseElementsAttr::get(vectorType, nan);
- Value NanValue =
+ Value nanValue =
spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
Value lhs =
spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
- NanValue, adaptor.getLhs());
+ nanValue, adaptor.getLhs());
Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..050c0ed
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithAttrToLLVMConversion
+ MLIRArithDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMathDialect
+ MLIRXeVMDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000..0fe31d0
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,167 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
+ return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ Type opTy = operand.getType();
+ // This pass only supports operations on vectors that are already in SPIRV
+ // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+ // supported vector sizes is done in other blocking optimization passes.
+ if (!isSPIRVCompatibleFloatOrVec(opTy))
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+ operandTypes.push_back(opTy);
+ }
+
+ auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+ auto funcOpRes = LLVM::lookupOrCreateFn(
+ rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+ operandTypes, op.getType());
+ assert(!failed(funcOpRes));
+ LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, funcOp, adaptor.getOperands());
+ // Preserve fastmath flags in our MLIR op when converting to llvm function
+ // calls, in order to allow further fastmath optimizations: We thus need to
+ // convert arith fastmath attrs into attrs recognized by llvm.
+ arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+ mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+ callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat())
+ return true;
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+ // intrinsics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ inline std::string
+ getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+ std::string mangledFuncName =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
+ if (type.isF32())
+ mangledFuncName += "f";
+ else if (type.isF16())
+ mangledFuncName += "Dh";
+ else if (type.isF64())
+ mangledFuncName += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ return mangledFuncName;
+ }
+
+ const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+ "__spirv_ocl_native_cos");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+ "__spirv_ocl_native_log");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sin");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+ "__spirv_ocl_native_tan");
+ if (convertArith)
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns, convertArith);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}
- // For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);
+ // For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
- zeroValues);
+ auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+ broadcast.getLoc(), v, poison, zeroValues);
+ rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b2580..dd9edc4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..fcbf66d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,7 +20,9 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -62,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
+ llvm_unreachable("Unknown XeGPU memory space");
}
// Get same bitwidth flat vector type of new element type.
@@ -185,6 +188,7 @@ class CreateNdDescToXeVMPattern
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -363,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Add a builder that creates
// offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+ Location loc, Value baseAddr, Value offset,
+ int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -390,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// Load result or Store valye Type can be vector or scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
- valOrResTy = op.getResult().getType();
+ valOrResTy =
+ this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -441,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// If offset is provided, we add them to the base pointer.
// Offset is in number of elements, we need to multiply by
// element byte size.
- basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+ basePtrI64 =
+ addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -504,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto resTy = op.getMemDesc();
+
+ // Create the result MemRefType with the same shape, element type, and
+ // memory space
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+ op.getSource(), zero, ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ Value basePtrStruct = adaptor.getMemDesc();
+ Value mdescVal = op.getMemDesc();
+ // Load result or Store value Type can be vector or scalar.
+ Value data;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+ data = op.getResult();
+ else
+ data = adaptor.getData();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ if (!valOrResVecTy)
+ valOrResVecTy = VectorType::get(1, data.getType());
+
+ int64_t elemBitWidth =
+ valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+
+ // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+ auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, basePtrStruct);
+
+ // Convert base pointer (index) to i32
+ Value basePtrI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
+
+ Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ linearOffset = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), linearOffset);
+ basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+ elemByteSize);
+
+ // convert base pointer (i32) to LLVM pointer type
+ basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
+
+ if (op.getSubgroupBlockIoAttr()) {
+ // If the attribute 'subgroup_block_io' is set to true, it lowers to
+ // xevm::BlockLoadOp (for loads) or xevm::BlockStoreOp (for stores).
+
+ Type intElemTy = rewriter.getIntegerType(elemBitWidth);
+ VectorType intVecTy =
+ VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+ }
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+ }
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+
+ if (valOrResVecTy.getNumElements() >= 1) {
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering for chunk load only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
+ }
+ }
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we
+ // need to handle this case separately.
+ auto scalarTy = valOrResVecTy.getElementType();
+ LLVM::LoadOp loadOp;
+ if (valOrResVecTy.getNumElements() == 1)
+ loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ else
+ loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -546,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+ elemByteSize);
}
}
// Default memory space is global.
@@ -784,6 +932,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ // Convert MemDescType into flattened MemRefType for SLM
+ typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+ Type elemTy = type.getElementType();
+ int numElems = type.getNumElements();
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ });
+
typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -878,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
- typeConverter.addSourceMaterialization(memrefMaterializationCast);
- typeConverter.addSourceMaterialization(ui64MaterializationCast);
- typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+ // If the result type of the original op is a single-element vector and the
+ // lowered type is a scalar, this materialization cast creates a
+ // single-element vector by broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ return vector::BroadcastOp::create(builder, loc, vecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(
+ singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
@@ -918,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
+ patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+ LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+ CreateMemDescOpPattern>(typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index f405d0c..c798adb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -757,13 +757,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
offset = numElements - 4l;
}
Type scaleSrcElemType = scaleSrcType.getElementType();
- auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
- scaleSrcElemType);
+ auto newSrcType =
+ VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
Value newScaleSrc =
vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
auto extract = vector::ExtractStridedSliceOp::create(
- rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
- ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
+ rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
+ ArrayRef{int64_t(1)});
rewriter.modifyOpInPlace(op, [&] {
op->setOperand(opIdx, extract);
setOpsel(opIdx, opsel);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7e5ce26..749e2ba 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
// Use "unused attribute" marker to silence clang-tidy warning stemming from
// the inability to see through "llvm::TypeSwitch".
template <>
-bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
- Region *src, Region *dest,
- const IRMapping &mapping) {
+[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
+ Region *dest,
+ const IRMapping &mapping) {
// If it's a valid dimension, we need to check that it remains so.
if (isValidDim(op.getResult(), src))
return remainsLegalAfterInline(
@@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
/// Simplify the map while exploiting information on the values in `operands`.
// Use "unused attribute" marker to silence warning stemming from the inability
// to see through the template expansion.
-static void LLVM_ATTRIBUTE_UNUSED
-simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
+[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
+ ArrayRef<Value> operands) {
assert(map.getNumInputs() == operands.size() && "invalid operands for map");
SmallVector<AffineExpr> newResults;
newResults.reserve(map.getNumResults());
@@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
return success(*map != initialMap);
}
+/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
+/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
+/// record a mapping from that sub-expression to
+/// `newVal` + sum({e1, e2, ..., eK} - exprsToRemove) in `replacementsMap`.
+/// If no entries were added to `replacementsMap`, nothing was found.
+static void shortenAddChainsContainingAll(
+ AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
+ AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
+ auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
+ if (!binOp) // Not a binary expression: there is no add chain to shorten.
+ return;
+ AffineExpr lhs = binOp.getLHS();
+ AffineExpr rhs = binOp.getRHS();
+ if (binOp.getKind() != AffineExprKind::Add) { // Not an add: only search both operands.
+ shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
+ shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
+ return;
+ }
+ SmallVector<AffineExpr> toPreserve;
+ llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove); // Terms not yet seen in this chain.
+ AffineExpr thisTerm = rhs; // The walk starts at the outermost right operand and moves left.
+ AffineExpr nextTerm = lhs;
+
+ while (thisTerm) {
+ if (!ourTracker.erase(thisTerm)) { // Not a term to remove: keep it and search inside it.
+ toPreserve.push_back(thisTerm);
+ shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
+ replacementsMap);
+ }
+ auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
+ if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
+ thisTerm = nextTerm; // End of the chain: visit the last term (null ends the loop).
+ nextTerm = AffineExpr();
+ } else {
+ thisTerm = nextBinOp.getRHS();
+ nextTerm = nextBinOp.getLHS();
+ }
+ }
+ if (!ourTracker.empty()) // An expected term is missing from this chain: no replacement.
+ return;
+ // We reverse the terms to be preserved here in order to preserve
+ // associativity between them.
+ AffineExpr newExpr = newVal;
+ for (AffineExpr preserved : llvm::reverse(toPreserve))
+ newExpr = newExpr + preserved;
+ replacementsMap.insert({e, newExpr});
+}
+
+/// If this map contains the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
+/// ...` (not necessarily in order) where the set of the `x_i` is the set of
+/// outputs of an `affine.delinearize_index` whose inverse is that expression,
+/// replace that expression with the input of that delinearize_index op.
+///
+/// `resultToReplace` is the value that was detected as the potential start to
+/// this replacement chain - if it isn't the rightmost result of the
+/// delinearization, this method fails. (This is intended to ensure we don't
+/// have redundant scans over the same expression).
+///
+/// While this currently only handles delinearizations with a constant basis,
+/// that isn't a fundamental limitation.
+///
+/// This is a utility function for `replaceDimOrSym` below.
+static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
+ AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
+ SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
+ if (!delinOp.getDynamicBasis().empty()) // Only a fully static basis is handled.
+ return failure();
+ if (resultToReplace != delinOp.getMultiIndex().back())
+ return failure();
+
+ MLIRContext *ctx = delinOp.getContext();
+ SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr()); // Per-result dim/symbol binding; null until found.
+ for (auto [pos, dim] : llvm::enumerate(dims)) {
+ auto asResult = dyn_cast_if_present<OpResult>(dim);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
+ }
+ for (auto [pos, sym] : llvm::enumerate(syms)) {
+ auto asResult = dyn_cast_if_present<OpResult>(sym);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
+ }
+ if (llvm::is_contained(resToExpr, AffineExpr())) // Some result is not among dims/syms.
+ return failure();
+
+ bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>); // Use a dim only if every result is bound to a dim.
+ int64_t stride = 1;
+ llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
+ // This isn't zip_equal since sometimes the delinearize basis is missing a
+ // size for the first result.
+ for (auto [binding, size] : llvm::zip(
+ llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
+ expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
+ stride *= size;
+ }
+ if (resToExpr.size() != delinOp.getStaticBasis().size()) // Basis has no entry for the first result: add its term explicitly.
+ expectedExprs.insert(resToExpr[0] * stride);
+
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ AffineExpr delinInExpr = isDimReplacement // Position of the operand appended to dims/syms below.
+ ? getAffineDimExpr(dims.size(), ctx)
+ : getAffineSymbolExpr(syms.size(), ctx);
+
+ for (AffineExpr e : map->getResults())
+ shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
+ if (replacements.empty()) // No result contained the full inverse expression.
+ return failure();
+
+ AffineMap origMap = *map;
+ if (isDimReplacement)
+ dims.push_back(delinOp.getLinearIndex());
+ else
+ syms.push_back(delinOp.getLinearIndex());
+ *map = origMap.replace(replacements, dims.size(), syms.size());
+
+ // Blank out dead dimensions and symbols
+ for (AffineExpr e : resToExpr) {
+ if (auto d = dyn_cast<AffineDimExpr>(e)) {
+ unsigned pos = d.getPosition();
+ if (!map->isFunctionOfDim(pos))
+ dims[pos] = nullptr;
+ }
+ if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
+ unsigned pos = s.getPosition();
+ if (!map->isFunctionOfSymbol(pos))
+ syms[pos] = nullptr;
+ }
+ }
+ return success();
+}
+
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
syms);
}
+ if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
+ return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
+ syms);
+ }
+
auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index cd216ef..4743941 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation(
/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
-bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] bool
mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
assert(!loops.empty() && "no loops provided");
@@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
return copyNestRoot;
}
-static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
-emitRemarkForBlock(Block &block) {
+[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) {
return block.getParentOp()->emitRemark();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 624519f..bc17990 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,37 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Get all the ReturnOp in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Collect, for every op in `returnOps`, its operand at index `pos`.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
+ return returnOp.getOperand(pos);
+ });
+}
+
+/// Return true if every value in `operands` is `argument` itself once any chain
+/// of memref.cast ops has been looked through (true for an empty `operands`).
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>()) // Strip casts down to their source.
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
+}
LogicalResult
@@ -64,47 +83,53 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
module.walk([&](func::CallOp callOp) {
if (func::FuncOp calledFunc =
dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
- callerMap[calledFunc].insert(callOp);
+ if (!calledFunc.isPublic() && !calledFunc.isExternal())
+ callerMap[calledFunc].insert(callOp);
}
});
for (auto funcOp : module.getOps<func::FuncOp>()) {
- if (funcOp.isExternal())
+ if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
// Compute erased results.
- SmallVector<Value> newReturnValues;
- BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
+ size_t numReturnOps = returnOps.size();
+ size_t numReturnValues = funcOp.getFunctionType().getNumResults();
+ SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
+ BitVector erasedResultIndices(numReturnValues);
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0; i < numReturnValues; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [returnOp, newReturnValue] :
+ llvm::zip(returnOps, newReturnValues))
+ returnOp.getOperandsMutable().assign(newReturnValue);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7ca09d9..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
return success();
}
+// Fold a shufflevector whose v1 is a single-element 1-D vector and whose mask
+// is the single index 0: the result is v1 itself. Otherwise nothing is folded.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+ // Check if operand 0 is a single element vector.
+ auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+ return {};
+ // Check if the mask is a single zero.
+ // Note: The mask is guaranteed to be non-empty.
+ if (getMask().size() != 1 || getMask()[0] != 0)
+ return {};
+ return getV1();
+}
+
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 01a16ce..ac35eea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..2a8c330 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() { // Accepts only a Float4E2M1FN destination type.
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx) // Prints the supported type inside the diagnostic.
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() {
" attribute");
}
+ // Validate layout combinations. According to the operation description, most
+ // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
+ // can use other layout combinations.
+ bool isM8N8K4_F16 =
+ (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
+ getMultiplicandAPtxType() == MMATypes::f16);
+
+ if (!isM8N8K4_F16) {
+ // For all other shapes/types, layoutA must be row and layoutB must be col
+ if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
+ return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
+ "layoutB = #nvvm.mma_layout<col> for shape <")
+ << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
+ << "> with element types "
+ << stringifyEnum(*getMultiplicandAPtxType()) << " and "
+ << stringifyEnum(*getMultiplicandBPtxType())
+ << ". Only m8n8k4 with f16 supports other layouts.";
+ }
+ }
+
return success();
}
@@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args; // Intrinsic arguments: translated `a`, then `b`.
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId = // Same ff_to_e2m1x2 rn/satfinite intrinsic, with or without relu.
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp. An absent attribute is valid.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+ std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+ if (!rangeAttr)
+ return success();
+
+ const llvm::APInt &lower = rangeAttr->getLower();
+ const llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition: Lower == Upper is only accepted when both are the min or the max value.
+ if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+ unsigned bitWidth = lower.getBitWidth();
+ llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth); // Only used to print the allowed values in the diagnostic.
+ llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+ return op->emitOpError(
+ "invalid range attribute: Lower == Upper, but they aren't min (")
+ << llvm::toString(minVal, 10, false) << ") or max ("
+ << llvm::toString(maxVal, 10, false)
+ << ") value! This is an invalid constant range.";
+ }
+
+ return success();
+}
+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
return builder.CreateBitCast(arg,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
Value yielded = getSourceSkipUnary(terminator->getOperand(0));
Operation *reductionOp = yielded.getDefiningOp();
- if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+ if (!reductionOp || reductionOp->getNumResults() != 1 ||
+ reductionOp->getNumOperands() != 2) {
errs << "expected reduction op to be binary";
return false;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d8f983f..6192d79 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3024,10 +3024,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
if (dynamicPointParseResult.has_value()) {
- Type ChunkSizesType;
+ Type chunkSizesType;
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
- parser.parseType(ChunkSizesType) ||
- parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
+ parser.parseType(chunkSizesType) ||
+ parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
result.operands)) {
return failure();
}
@@ -3399,9 +3399,9 @@ void transform::ContinuousTileSizesOp::getEffects(
}
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
- Type targetType, Type tile_sizes,
+ Type targetType, Type tileSizes,
Type) {
- printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+ printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
}
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a012..1382c7ac 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+ ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR
DEPENDS
MLIRMemRefOpsIncGen
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRDialectUtils
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda..94947b7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,45 @@ public:
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder // Rebuilds the op from its constified offset/sizes/strides, then casts back to the original type.
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ unsigned srcStaticCount = llvm::count_if( // Offset/size/stride entries that are already attributes.
+ llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+
+ SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
+
+ // TODO: Using counting comparison instead of direct comparison because
+ // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+ // IntegerAttrs, while constifyIndexValues (and therefore
+ // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+ if (srcStaticCount ==
+ llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+ return failure(); // Same number of static entries after constifying: nothing to rewrite.
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); // The cast is created with the original result type.
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
@@ -3437,6 +3471,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
+void SubViewOp::inferStridedMetadataRanges( // Derives the result's offset/size/stride ranges from the source ranges and this op's operands.
+ ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
+ SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
+ auto isUninitialized =
+ +[](IntegerValueRange range) { return range.isUninitialized(); };
+
+ // Bail early if any of the operands' metadata is not ready:
+ SmallVector<IntegerValueRange> offsetOperands =
+ getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
+ if (llvm::any_of(offsetOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> sizeOperands =
+ getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
+ if (llvm::any_of(sizeOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> stridesOperands =
+ getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
+ if (llvm::any_of(stridesOperands, isUninitialized))
+ return;
+
+ StridedMetadataRange sourceRange =
+ ranges[getSourceMutable().getOperandNumber()];
+ if (sourceRange.isUninitialized()) // Source metadata is not ready either.
+ return;
+
+ ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+
+ // Get the dropped dims.
+ llvm::SmallBitVector droppedDims = getDroppedDims();
+
+ // Compute the new offset, strides and sizes.
+ ConstantIntRanges offset = sourceRange.getOffsets()[0]; // Start from the source offset range.
+ SmallVector<ConstantIntRanges> strides, sizes;
+
+ for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+ bool dropped = droppedDims.test(i);
+ // Compute the new offset: add offset[i] * sourceStride[i], for dropped dimensions too.
+ ConstantIntRanges off =
+ intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
+ offset = intrange::inferAdd({offset, off});
+
+ // Skip dropped dimensions.
+ if (dropped)
+ continue;
+ // Multiply the strides.
+ strides.push_back(
+ intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
+ // Get the sizes.
+ sizes.push_back(sizeOperands[i].getValue());
+ }
+
+ setMetadata(getResult(), // Publish a ranked range holding a single offset entry.
+ StridedMetadataRange::getRanked(
+ SmallVector<ConstantIntRanges>({std::move(offset)}),
+ std::move(sizes), std::move(strides)));
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 49b7162..6f815ae 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -121,7 +121,7 @@ struct EmulateWideIntPass final
[&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
RewritePatternSet patterns(ctx);
- // Add common pattenrs to support contants, functions, etc.
+ // Add common patterns to support constants, functions, etc.
arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6564a4e..90cbbd8 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
@@ -39,6 +40,16 @@ static bool isScalarLikeType(Type type) {
return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
}
+/// Helper function to attach the `VarName` attribute to an operation
+/// if a variable name is provided; does nothing when `varName` is empty.
+static void attachVarNameAttr(Operation *op, OpBuilder &builder,
+ StringRef varName) {
+ if (!varName.empty()) {
+ auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
+ op->setAttr(acc::getVarNameAttrName(), varNameAttr);
+ }
+}
+
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
MemRefType> {
@@ -74,14 +85,18 @@ struct MemRefPointerLikeModel
}
mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
- StringRef varName, Type varType,
- Value originalVar) const {
+ StringRef varName, Type varType, Value originalVar,
+ bool &needsFree) const {
auto memrefTy = cast<MemRefType>(pointer);
// Check if this is a static memref (all dimensions are known) - if yes
// then we can generate an alloca operation.
- if (memrefTy.hasStaticShape())
- return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+ if (memrefTy.hasStaticShape()) {
+ needsFree = false; // alloca doesn't need deallocation
+ auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
+ attachVarNameAttr(allocaOp, builder, varName);
+ return allocaOp.getResult();
+ }
// For dynamic memrefs, extract sizes from the original variable if
// provided. Otherwise they cannot be handled.
@@ -99,8 +114,11 @@ struct MemRefPointerLikeModel
// Note: We only add dynamic sizes to the dynamicSizes array
// Static dimensions are handled automatically by AllocOp
}
- return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
- .getResult();
+ needsFree = true; // alloc needs deallocation
+ auto allocOp =
+ memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
+ attachVarNameAttr(allocOp, builder, varName);
+ return allocOp.getResult();
}
// TODO: Unranked not yet supported.
@@ -108,10 +126,14 @@ struct MemRefPointerLikeModel
}
bool genFree(Type pointer, OpBuilder &builder, Location loc,
- TypedValue<PointerLikeType> varPtr, Type varType) const {
- if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ TypedValue<PointerLikeType> varToFree, Value allocRes,
+ Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
+ // Use allocRes if provided to determine the allocation type
+ Value valueToInspect = allocRes ? allocRes : memrefValue;
+
// Walk through casts to find the original allocation
- Value currentValue = memrefValue;
+ Value currentValue = valueToInspect;
Operation *originalAlloc = nullptr;
// Follow the chain of operations to find the original allocation
@@ -150,7 +172,7 @@ struct MemRefPointerLikeModel
return true;
}
if (isa<memref::AllocOp>(originalAlloc)) {
- // This is an alloc - generate dealloc
+ // This is an alloc - generate dealloc on varToFree
memref::DeallocOp::create(builder, loc, memrefValue);
return true;
}
@@ -1003,6 +1025,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+//===----------------------------------------------------------------------===//
+// Recipe Region Helpers
+//===----------------------------------------------------------------------===//
+
+/// Build the init block of a privatization recipe. Its arguments are the
+/// original variable followed by `bounds`; its body yields the private value.
+/// Returns nullptr on failure. `needsFree` is forwarded to the allocation hook.
+static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
+ Type varType, StringRef varName,
+ ValueRange bounds,
+ bool &needsFree) {
+ // Create init block with arguments: original value + bounds
+ SmallVector<Type> argTypes{varType};
+ SmallVector<Location> argLocs{loc};
+ for (Value bound : bounds) {
+ argTypes.push_back(bound.getType());
+ argLocs.push_back(loc);
+ }
+
+ auto initBlock = std::make_unique<Block>();
+ initBlock->addArguments(argTypes, argLocs);
+ builder.setInsertionPointToStart(initBlock.get()); // NOTE: the builder is left positioned inside this block.
+
+ Value privatizedValue;
+
+ // Get the block argument that represents the original variable
+ Value blockArgVar = initBlock->getArgument(0);
+
+ // MappableType takes precedence; anything else must be a PointerLikeType.
+ if (isa<MappableType>(varType)) {
+ auto mappableTy = cast<MappableType>(varType);
+ auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
+ privatizedValue = mappableTy.generatePrivateInit(
+ builder, loc, typedVar, varName, bounds, {}, needsFree);
+ if (!privatizedValue)
+ return nullptr;
+ } else {
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ // Use PointerLikeType's allocation API with the block argument
+ privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
+ blockArgVar, needsFree);
+ if (!privatizedValue)
+ return nullptr;
+ }
+
+ // Terminate the init block by yielding the privatized value
+ acc::YieldOp::create(builder, loc, privatizedValue);
+
+ return initBlock;
+}
+
+/// Build the copy block of a firstprivate recipe. Its arguments are the
+/// original value, the privatized value, then `bounds`. Returns nullptr on failure.
+/// TODO: Handle MappableType - it does not yet have a copy API.
+static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
+ Type varType,
+ ValueRange bounds) {
+ // Create copy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> copyArgTypes{varType, varType};
+ SmallVector<Location> copyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ copyArgTypes.push_back(bound.getType());
+ copyArgLocs.push_back(loc);
+ }
+
+ auto copyBlock = std::make_unique<Block>();
+ copyBlock->addArguments(copyArgTypes, copyArgLocs);
+ builder.setInsertionPointToStart(copyBlock.get()); // NOTE: the builder is left positioned inside this block.
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a copy API.
+ // A type that is mappable but not pointer-like cannot be copied: fail.
+ if (isMappable && !isPointerLike)
+ return nullptr;
+
+ // Emit the copy only for pointer-like types; otherwise only the terminator
+ if (isPointerLike) {
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ Value originalArg = copyBlock->getArgument(0);
+ Value privatizedArg = copyBlock->getArgument(1);
+
+ // Privatized argument is passed first, original second (presumably destination then source - confirm)
+ if (!pointerLikeTy.genCopy(
+ builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
+ cast<TypedValue<PointerLikeType>>(originalArg), varType))
+ return nullptr;
+ }
+
+ // Add terminator to copy block
+ acc::TerminatorOp::create(builder, loc);
+
+ return copyBlock;
+}
+
+/// Create and populate a destroy region for privatization recipes.
+/// Returns the destroy block on success, or nullptr if it cannot be built.
+static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder,
+ Location loc, Type varType,
+ Value allocRes,
+ ValueRange bounds) {
+ // Create destroy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> destroyArgTypes{varType, varType};
+ SmallVector<Location> destroyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ destroyArgTypes.push_back(bound.getType());
+ destroyArgLocs.push_back(loc);
+ }
+
+ auto destroyBlock = std::make_unique<Block>();
+ destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
+ builder.setInsertionPointToStart(destroyBlock.get()); // NOTE: the builder is left positioned inside this block.
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a deallocation API.
+ // A type that is mappable but not pointer-like cannot be freed here: fail.
+ if (isMappable && !isPointerLike)
+ return nullptr;
+
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ auto privatizedArg =
+ cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
+ // Free the privatized block argument; allocRes is passed alongside it
+ if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType))
+ return nullptr;
+
+ acc::TerminatorOp::create(builder, loc);
+
+ return destroyBlock;
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1050,6 +1208,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Neither interface is implemented: no recipe can be generated
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ // The region helpers move the builder's insertion point; guard against that
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Save the original insertion point for creating the recipe operation later
+ auto originalInsertionPoint = builder.saveInsertionPoint();
+
+ bool needsFree = false;
+ auto initBlock =
+ createInitRegion(builder, loc, varType, varName, bounds, needsFree);
+ if (!initBlock)
+ return std::nullopt;
+
+ // Only create destroy region if the allocation needs deallocation
+ std::unique_ptr<Block> destroyBlock;
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
+ if (!destroyBlock)
+ return std::nullopt;
+ }
+
+ // All blocks were built: create the recipe operation at the original
+ // insertion point
+ builder.restoreInsertionPoint(originalInsertionPoint);
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Hand ownership of the blocks over to the recipe's regions
+ recipe.getInitRegion().push_back(initBlock.release());
+ if (destroyBlock)
+ recipe.getDestroyRegion().push_back(destroyBlock.release());
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1080,6 +1287,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<FirstprivateRecipeOp>
+FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Neither interface is implemented: no recipe can be generated
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ // The region helpers move the builder's insertion point; guard against that
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Save the original insertion point for creating the recipe operation later
+ auto originalInsertionPoint = builder.saveInsertionPoint();
+
+ bool needsFree = false;
+ auto initBlock =
+ createInitRegion(builder, loc, varType, varName, bounds, needsFree);
+ if (!initBlock)
+ return std::nullopt;
+
+ auto copyBlock = createCopyRegion(builder, loc, varType, bounds);
+ if (!copyBlock)
+ return std::nullopt;
+
+ // Only create destroy region if the allocation needs deallocation
+ std::unique_ptr<Block> destroyBlock;
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
+ if (!destroyBlock)
+ return std::nullopt;
+ }
+
+ // All blocks were built: create the recipe operation at the original
+ // insertion point
+ builder.restoreInsertionPoint(originalInsertionPoint);
+ auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Hand ownership of the blocks over to the recipe's regions
+ recipe.getInitRegion().push_back(initBlock.release());
+ recipe.getCopyRegion().push_back(copyBlock.release());
+ if (destroyBlock)
+ recipe.getDestroyRegion().push_back(destroyBlock.release());
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// ReductionRecipeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 135c033..645cbff 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op,
}
template <typename It>
-bool isUnique(It begin, It end) {
+static bool isUnique(It begin, It end) {
if (begin == end) {
return true;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index a1711a6..069191c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) {
/// Helper function for `assertUsageConsistency` to better handle SMLoc
/// mismatches.
-LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
-minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
+[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1,
+ llvm::SMLoc sm2) {
const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index f539502..684c088 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
#ifndef NDEBUG
-LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
- Location loc, Value memref) {
+[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc,
+ Value memref) {
memref = memref::CastOp::create(
builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 5aad671..1cba1bb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
- return TargetEnvAttr::get(context, Level::eightK,
+ return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); // Renders the version as <major>.<minor>.
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
index bcb880a..a0661e4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -61,8 +61,8 @@ public:
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
- const auto targetEnvAttr =
- TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ const auto targetEnvAttr = TargetEnvAttr::get(
+ ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 20f9333..f072e3e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//
template <typename T>
-FailureOr<SmallVector<T>>
-TosaProfileCompliance::getOperatorDefinition(Operation *op,
- CheckCondition &condition) {
+FailureOr<OpComplianceInfo<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};
- return findMatchedProfile<T>(op, it->second, condition);
+ return findMatchedEntry<T>(op, it->second);
}
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
- if (failed(maybeOpRequiredMode)) {
+ const auto maybeOpDefinition = getOperatorDefinition<T>(op);
+ if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
- int mode_count = 0;
+ int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
- mode_count += cands.size();
+ modeCount += cands.size();
}
op->emitOpError() << "illegal: requires"
- << (mode_count > 1 ? " any of " : " ") << "["
+ << (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
// Find the required profiles or extensions according to the operand type
// combination.
- const auto opRequiredMode = maybeOpRequiredMode.value();
+ const auto opDefinition = maybeOpDefinition.value();
+ const SmallVector<T> opRequiredMode = opDefinition.mode;
+ const CheckCondition condition = opDefinition.condition;
+
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}
+ // Ensure the matched op compliance version does not exceed the target
+ // specification version.
+ const VersionedTypeInfo versionedTypeInfo =
+ opDefinition.operandTypeInfoSet[0];
+ const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
+ const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
+ if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
+ op->emitOpError() << "illegal: the target specification version ("
+ << stringifyVersion(targetVersion)
+ << ") is not backwards compatible with the op compliance "
+ "specification version ("
+ << stringifyVersion(complianceVersion) << ")\n";
+ return failure();
+ }
+
return success();
}
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
- const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();
- const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
- (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ const bool hasEntry =
+ (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
- for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
+ for (const auto &versionedTypeInfos :
+ complianceInfos.operandTypeInfoSet) {
+ const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
-SmallVector<T> TosaProfileCompliance::findMatchedProfile(
- Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition) {
+OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
+ Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};
for (size_t i = 0; i < compInfo.size(); i++) {
- SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
- for (SmallVector<TypeInfo> expected : sets) {
+ SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
+ for (const auto &set : sets) {
+ SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");
- bool is_found = true;
+ bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
- is_found = false;
+ isFound = false;
break;
}
}
- if (is_found == true) {
- condition = compInfo[i].condition;
- return compInfo[i].mode;
+ if (isFound == true) {
+ SmallVector<VersionedTypeInfo> typeInfoSet{set};
+ OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
+ compInfo[i].condition};
+ return info;
}
}
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 9a24c2b..a2cff6a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -21,10 +21,10 @@ using namespace mlir;
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0..45c54c7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// ```mlir
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// as,
+///
+/// ```mlir
+/// %out = arith.constant dense<false> : vector<3xi1>.
+/// ```
+///
+/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
+/// is false at ALL indices we fold. If the constant was 1, then
+/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Note: this folder only works for the case where the constant (`%cst` above)
+/// is the second operand of the comparison. The arith.cmpi canonicalizer will
+/// ensure that constants are always second (on the right).
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+ for (OpOperand &use : stepOp.getResult().getUses()) {
+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+ if (!cmpiOp)
+ continue;
+
+ // arith.cmpi canonicalizer makes constants final operands, so only the
+ // case where the step is operand 0 is handled.
+ const unsigned stepOperandNumber = use.getOperandNumber();
+ if (stepOperandNumber != 0)
+ continue;
+
+ // Check that operand 1 is a constant.
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+ std::optional<int64_t> maybeConstValue =
+ getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
+
+ int64_t constValue = maybeConstValue.value();
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult and uge: uniform once the constant is at least stepSize.
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
+
+ // Handle ule and ugt: uniform once the constant is at least stepSize - 1.
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize - 1 <= constValue) {
+ return pred == arith::CmpIPredicate::ule;
+ }
+
+ // Handle eq and ne: uniform once the constant is at least stepSize.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
+
+ return std::nullopt;
+ }();
+
+ if (!maybeSplat.has_value())
+ continue;
+
+ rewriter.setInsertionPointAfter(cmpiOp);
+
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
+
+ auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
+
+ rewriter.replaceOp(cmpiOp, splat);
+ return success(); // At most one cmpi user is replaced per invocation.
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StepCompareFolder>(context); // Registers the step/cmpi constant folder as a canonicalization.
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f..12e6475 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
- // 1. recording the unique first position at which the value is yielded.
+ // 1. recording the unique first position at which the value with uses is
+ // yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
- if (result.use_empty() || !it.second)
+ if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());
- llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
- origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
- // Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // Replace the old `WarpOp` with the new one that has additional yield
+ // values and types.
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
- rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
- static_cast<bool>(ifOp.thenBlock()),
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+ newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
- newWarpOp.getResult(warpResRangeStart));
+ newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `IfOp`.
- for (auto [origIdx, newIdx] : origToNewYieldIdx)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(ifOp);
- rewriter.eraseOp(warpOp);
return success();
}
@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
- // types. We also create a mapping between the non-`ForOp` yielded value
- // index and the corresponding new `WarpOp` yield value index (needed to
- // update users later).
- llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+ // types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
- nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
- newForOpOperands.push_back(newWarpOp.getResult(i));
+ newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
- innerWarpInput.push_back(newWarpOp.getResult(i));
+ innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
- // Update the users of original `WarpOp` results that were coming from the
+ // Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `ForOp`.
- for (auto [origIdx, newIdx] : nonForResultMapping)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5..fbae098 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
+ int64_t targetShapeRank = targetShape->size();
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- // Bail-out if rank(source) != rank(target). The main limitation here is the
- // fact that `ExtractStridedSlice` requires the rank for the input and
- // output to match. If needed, we can relax this later.
- if (originalSize.size() != targetShape->size())
- return rewriter.notifyMatchFailure(
- op, "expected input vector rank to match target shape rank");
+ int64_t originalShapeRank = originalSize.size();
+
Location loc = op->getLoc();
+
+ // Handle rank mismatch by adding leading unit dimensions to targetShape
+ SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+ int64_t rankDiff = originalShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(),
+ adjustedTargetShape.begin() + rankDiff, 1);
+ std::copy(targetShape->begin(), targetShape->end(),
+ adjustedTargetShape.begin() + rankDiff);
+
+ int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(targetShape->size(), 1);
- VectorType newVecType =
+ SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
+ VectorType unrolledVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalSize, *targetShape)) {
+ StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
extractOperands.push_back(operand.get());
continue;
}
- extractOperands.push_back(
- rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, operand.get(), offsets, *targetShape, strides));
+ Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+ // Reshape to remove leading unit dims if needed
+ if (adjustedTargetShapeRank > targetShapeRank) {
+ extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, VectorType::get(*targetShape, vecType.getElementType()),
+ extracted);
+ }
+ extractOperands.push_back(extracted);
}
+
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, newVecType);
+ rewriter, loc, op, extractOperands, unrolledVecType);
+
+ Value computeResult = newOp->getResult(0);
+
+ // Use strides sized to targetShape for proper insertion
+ SmallVector<int64_t> insertStrides =
+ (adjustedTargetShapeRank > targetShapeRank)
+ ? SmallVector<int64_t>(targetShapeRank, 1)
+ : strides;
+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, newOp->getResult(0), result, offsets, strides);
+ loc, computeResult, result, offsets, insertStrides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a..c809c502 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
- // transpose pattens. Otherwise, these patterns are not applicable.
+ // transpose patterns. Otherwise, these patterns are not applicable.
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
op.getPermutation()))
return failure();
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 89b62a2..a514ea9 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
@@ -39,28 +40,6 @@ void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
opPrinter.printKeywordOrString("else ");
opPrinter.printRegion(elseRegion);
}
-
-ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
- std::string keyword;
- auto initLocation = opParser.getCurrentLocation();
- std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
- if (keyword == "nested" or keyword == "") {
- visibility = StringAttr::get(opParser.getContext(), "nested");
- return ParseResult::success();
- }
-
- if (keyword == "public" || keyword == "private") {
- visibility = StringAttr::get(opParser.getContext(), keyword);
- return ParseResult::success();
- }
- opParser.emitError(initLocation, "expecting symbol visibility");
- return ParseResult::failure();
-}
-
-void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
- Attribute visibility) {
- opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
-}
} // namespace
#define GET_OP_CLASSES
@@ -167,10 +146,23 @@ Block *FuncOp::addEntryBlock() {
void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, FunctionType funcType) {
- FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
+ FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {});
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto *ctx = parser.getContext();
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ bool exported{false};
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ exported = true;
+ }
+
auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
@@ -191,11 +183,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return builder.getFunctionType(argTypesWithoutLocal, results);
};
-
- return function_interface_impl::parseFunctionOp(
+ auto funcParseRes = function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ if (exported)
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ return funcParseRes;
}
LogicalResult FuncOp::verifyBody() {
@@ -224,9 +218,18 @@ LogicalResult FuncOp::verifyBody() {
}
void FuncOp::print(OpAsmPrinter &p) {
+ /// If exported, print it before and mask it before printing
+ /// using generic interface.
+ auto exported = getExported();
+ if (exported) {
+ p << " exported";
+ removeExportedAttr();
+ }
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
+ if (exported)
+ setExported(true);
}
//===----------------------------------------------------------------------===//
@@ -237,38 +240,37 @@ void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, StringRef moduleName,
StringRef importName, FunctionType type) {
FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, {}, {}, odsBuilder.getStringAttr("nested"));
+ type, {}, {});
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
-
-void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, Type type, bool isMutable) {
- GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
- odsBuilder.getStringAttr("nested"));
-}
-
// Custom formats
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr symbolName;
Type globalType;
auto *ctx = parser.getContext();
- ParseResult res = parser.parseSymbolName(
- symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ }
+ res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+ result.attributes);
res = parser.parseType(globalType);
result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
std::string mutableString;
res = parser.parseOptionalKeywordOrString(&mutableString);
if (res.succeeded() && mutableString == "mutable")
result.addAttribute("isMutable", UnitAttr::get(ctx));
- std::string visibilityString;
- res = parser.parseOptionalKeywordOrString(&visibilityString);
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, visibilityString));
+
res = parser.parseColon();
Region *globalInitRegion = result.addRegion();
res = parser.parseRegion(*globalInitRegion);
@@ -276,11 +278,11 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
}
void GlobalOp::print(OpAsmPrinter &printer) {
+ if (getExported())
+ printer << " exported";
printer << " @" << getSymName().str() << " " << getType();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " :";
Region &body = getRegion();
if (!body.empty()) {
@@ -319,13 +321,6 @@ GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// GlobalImportOp
//===----------------------------------------------------------------------===//
-void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, Type type, bool isMutable) {
- GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, isMutable, odsBuilder.getStringAttr("nested"));
-}
-
ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
auto *ctx = parser.getContext();
ParseResult res = parseImportOp(parser, result);
@@ -335,12 +330,8 @@ ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
if (res.succeeded() && mutableOrSymVisString == "mutable") {
result.addAttribute("isMutable", UnitAttr::get(ctx));
- res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
}
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, mutableOrSymVisString));
res = parser.parseColon();
Type importedType;
@@ -356,8 +347,6 @@ void GlobalImportOp::print(OpAsmPrinter &printer) {
<< "\" as @" << getSymName();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " : " << getType();
}
@@ -431,27 +420,6 @@ LogicalResult LocalTeeOp::verify() {
Block *LoopOp::getLabelTarget() { return &getBody().front(); }
//===----------------------------------------------------------------------===//
-// MemOp
-//===----------------------------------------------------------------------===//
-
-void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, LimitType limit) {
- MemOp::build(odsBuilder, odsState, symbol, limit,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// MemImportOp
-//===----------------------------------------------------------------------===//
-
-void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, LimitType limits) {
- MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- limits, odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
// ReinterpretOp
//===----------------------------------------------------------------------===//
@@ -471,24 +439,3 @@ LogicalResult ReinterpretOp::verify() {
//===----------------------------------------------------------------------===//
void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
-
-//===----------------------------------------------------------------------===//
-// TableOp
-//===----------------------------------------------------------------------===//
-
-void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, TableType type) {
- TableOp::build(odsBuilder, odsState, symbol, type,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// TableImportOp
-//===----------------------------------------------------------------------===//
-
-void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, TableType type) {
- TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, odsBuilder.getStringAttr("nested"));
-}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d..1599ae9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// a helper utility to perform binary operation on OpFoldResult.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and a
+// constant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+ return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// a helper utility to perform division operation on OpFoldResult and int64_t.
+#define div(a, b) \
+ genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform remainder operation on OpFoldResult and int64_t.
+#define rem(a, b) \
+ genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform multiply operation on OpFoldResult and int64_t.
+#define mul(a, b) \
+ genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform addition operation on two OpFoldResult.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// block the given offsets according to the block shape
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ return blockedOffsets;
+}
+
+// Get strides as vector of integer for MemDesc.
+SmallVector<int64_t> MemDescType::getStrideShape() {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+
+ ArrayAttr strideAttr = getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+
+ SmallVector<int64_t> innerBlkShape = getBlockShape();
+
+ // get perm from FCD to LCD
+ // perm[i] = the dim with i-th smallest stride
+ SmallVector<int, 4> perm =
+ llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+ assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+
+ SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+ innerBlkStride[perm[0]] = 1;
+ for (size_t i = 1; i < perm.size(); ++i)
+ innerBlkStride[perm[i]] =
+ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+ // compute the original matrix shape using the stride info
+ // and compute the number of blocks in each dimension
+ // The shape of highest dim can't be derived from stride info,
+ // but doesn't impact the stride computation for blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+ SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+ BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
+
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
+
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
+ outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+ }
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+ return blockedStrides;
+}
+
+// Calculate the linear offset using the blocked offsets and stride
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets) {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+ SmallVector<int64_t> blockShape = getBlockShape();
+ SmallVector<int64_t> strides = getStrideShape();
+ SmallVector<OpFoldResult> blockedOffsets;
+
+ // blockshape equal to matrixshape means no blocking
+ if (llvm::equal(blockShape, matrixShape)) {
+ // remove the outer dims from strides
+ strides.erase(strides.begin(), strides.begin() + matrixShape.size());
+ } else {
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ // say the original offset is [y, x], and the block shape is [By, Bx],
+ // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+ offsets = blockedOffsets;
+ }
+
+ // Start with initial value as matrix descriptor's base offset.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ }
+
+ return linearOffset;
+}
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788..abd12e2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,8 +20,8 @@
#define DEBUG_TYPE "xegpu"
-namespace mlir {
-namespace xegpu {
+using namespace mlir;
+using namespace mlir::xegpu;
static bool isSharedMemory(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
@@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
return success();
}
+LogicalResult
+IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
+ UnitAttr subgroup_block_io,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!dataTy) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ else
+ return success();
+ }
+
+ if (mdescTy.getRank() != 2)
+ return emitError() << "mem_desc must be 2D.";
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+
+ if (dataShape.size() == 2) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitError() << "data shape must not exceed mem_desc shape.";
+ } else {
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ // if the subgroup_block_io attribute is set, mdescTy must have block
+ // attribute
+ if (subgroup_block_io && !blockShape.size())
+ return emitError() << "mem_desc must have block attribute when "
+ "subgroup_block_io is set.";
+ // if the subgroup_block_io attribute is set, the memdesc should be row
+ // major
+ if (subgroup_block_io && mdescTy.isColMajor())
+ return emitError() << "mem_desc should be row major when "
+ "subgroup_block_io is set.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ // Call the generated builder with all parameters (including optional ones as
+ // nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
- VectorType resTy = getRes().getType();
- MemDescType mdescTy = getMemDesc().getType();
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
+ auto resTy = dyn_cast<VectorType>(getRes().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
- ArrayRef<int64_t> valueShape = resTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed mem_desc shape.");
- return success();
+ return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
- VectorType dataTy = getData().getType();
- MemDescType mdescTy = getMemDesc().getType();
-
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
-
- ArrayRef<int64_t> dataShape = dataTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("data shape must not exceed mem_desc shape.");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// XeGPU_MemDescSubviewOp
-//===----------------------------------------------------------------------===//
-
-void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
- Type resTy, Value src,
- llvm::ArrayRef<OpFoldResult> offsets) {
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<int64_t> staticOffsets;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
-}
-
-LogicalResult MemDescSubviewOp::verify() {
- MemDescType srcTy = getSrc().getType();
- MemDescType resTy = getRes().getType();
- ArrayRef<int64_t> srcShape = srcTy.getShape();
- ArrayRef<int64_t> resShape = resTy.getShape();
-
- if (srcTy.getRank() < resTy.getRank())
- return emitOpError("result rank must not exceed source rank.");
- if (llvm::any_of(
- llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed source shape.");
-
- if (srcTy.getStrides() != resTy.getStrides())
- return emitOpError("result must inherit the source strides.");
-
- return success();
+ auto dataTy = dyn_cast<VectorType>(getData().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
+ return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
-} // namespace xegpu
-} // namespace mlir
-
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0f..aafa1b7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- VectorType valueTy = op.getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+ assert(valueTy && "the value type must be vector type!");
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
return failure();
@@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
return failure();
Location loc = op.getLoc();
- VectorType valueTy = op.getData().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
+ assert(valueTy && "the value type must be vector type!");
ArrayRef<int64_t> shape = valueTy.getShape();
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c28d2fc..31a967d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
return failure();
ArrayRef<int64_t> wgShape = op.getDataShape();
- VectorType valueTy = op.getRes().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
+ assert(valueTy && "the value type must be vector type!");
Type elemTy = valueTy.getElementType();
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..4d81918 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl {
}
// Otherwise, try to load the source file.
- std::string ignored;
- unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
+ auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
+ if (!bufferOrErr)
+ return 0;
+ unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc());
filenameToBufId[filename] = id;
return id;
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 1fa04ed..5f63fe6 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -121,6 +121,11 @@ namespace mlir {
class MLIRContextImpl {
public:
//===--------------------------------------------------------------------===//
+ // Remark
+ //===--------------------------------------------------------------------===//
+ std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
+
+ //===--------------------------------------------------------------------===//
// Debugging
//===--------------------------------------------------------------------===//
@@ -135,11 +140,6 @@ public:
DiagnosticEngine diagEngine;
//===--------------------------------------------------------------------===//
- // Remark
- //===--------------------------------------------------------------------===//
- std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
-
- //===--------------------------------------------------------------------===//
// Options
//===--------------------------------------------------------------------===//
@@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
}
-MLIRContext::~MLIRContext() = default;
+MLIRContext::~MLIRContext() {
+ // finalize remark engine before destroying anything else.
+ impl->remarkEngine.reset();
+}
/// Copy the specified array of elements into memory managed by the provided
/// bump pointer allocator. This assumes the elements are all PODs.
@@ -1201,7 +1204,7 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
/// present in result expressions is less than `dimCount` and the highest index
/// of symbolic identifier present in result expressions is less than
/// `symbolCount`.
-LLVM_ATTRIBUTE_UNUSED static bool
+[[maybe_unused]] static bool
willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results) {
int64_t maxDimPosition = -1;
diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp
index a55f61a..031eae2 100644
--- a/mlir/lib/IR/Remarks.cpp
+++ b/mlir/lib/IR/Remarks.cpp
@@ -16,7 +16,7 @@
#include "llvm/ADT/StringRef.h"
using namespace mlir::remark::detail;
-
+using namespace mlir::remark;
//------------------------------------------------------------------------------
// Remark
//------------------------------------------------------------------------------
@@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
// Header: [Type] pass:remarkName
StringRef type = getRemarkTypeString();
- StringRef categoryName = getFullCategoryName();
+ StringRef categoryName = getCombinedCategoryName();
StringRef name = remarkName;
os << '[' << type << "] ";
@@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
os << "Function=" << getFunction() << " | ";
if (printLocation) {
- if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation()))
+ if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) {
os << " @" << flc.getFilename() << ":" << flc.getLine() << ":"
<< flc.getColumn();
+ }
}
printArgs(os, getArgs());
@@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const {
r.RemarkType = getRemarkType();
r.RemarkName = getRemarkName();
// MLIR does not use passes; instead, it has categories and sub-categories.
- r.PassName = getFullCategoryName();
+ r.PassName = getCombinedCategoryName();
r.FunctionName = getFunction();
r.Loc = locLambda();
for (const Remark::Arg &arg : getArgs()) {
@@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
// RemarkEngine
//===----------------------------------------------------------------------===//
-void RemarkEngine::report(const Remark &&remark) {
+void RemarkEngine::reportImpl(const Remark &remark) {
// Stream the remark
- if (remarkStreamer)
+ if (remarkStreamer) {
remarkStreamer->streamOptimizationRemark(remark);
+ }
// Print using MLIR's diagnostic
if (printAsEmitRemarks)
emitRemark(remark.getLocation(), remark.getMsg());
}
+void RemarkEngine::report(const Remark &&remark) {
+ if (remarkEmittingPolicy)
+ remarkEmittingPolicy->reportRemark(remark);
+}
+
RemarkEngine::~RemarkEngine() {
+ if (remarkEmittingPolicy)
+ remarkEmittingPolicy->finalize();
+
if (remarkStreamer)
remarkStreamer->finalize();
}
-llvm::LogicalResult
-RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
- std::string *errMsg) {
- // If you need to validate categories/filters, do so here and set errMsg.
+llvm::LogicalResult RemarkEngine::initialize(
+ std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ std::string *errMsg) {
+
remarkStreamer = std::move(streamer);
+
+ auto reportFunc =
+ std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1);
+ remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc)));
+
+ this->remarkEmittingPolicy = std::move(remarkEmittingPolicy);
return success();
}
@@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks,
}
llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
- MLIRContext &ctx,
- std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
- const remark::RemarkCategories &cats, bool printAsEmitRemarks) {
+ MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ const RemarkCategories &cats, bool printAsEmitRemarks) {
auto engine =
- std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats);
+ std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats);
std::string errMsg;
- if (failed(engine->initialize(std::move(streamer), &errMsg))) {
+ if (failed(engine->initialize(std::move(streamer),
+ std::move(remarkEmittingPolicy), &errMsg))) {
llvm::report_fatal_error(
llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg);
}
@@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
return success();
}
+
+//===----------------------------------------------------------------------===//
+// Remark emitting policies
+//===----------------------------------------------------------------------===//
+
+namespace mlir::remark {
+RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default;
+RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default;
+} // namespace mlir::remark
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 388de1c..f96af02 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
FunctionInterfaces.cpp
IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
+ InferStridedMetadataInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
MemOpInterfaces.cpp
@@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces
add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
+
+add_mlir_library(MLIRInferStridedMetadataInterface
+ InferStridedMetadataInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+ DEPENDS
+ MLIRInferStridedMetadataInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRInferIntRangeInterface
+ MLIRIR
+)
+
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_library(MLIRLoopLikeInterface
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d..84fc9b8 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
return os;
}
+SmallVector<IntegerValueRange>
+mlir::getIntValueRanges(ArrayRef<OpFoldResult> values,
+ GetIntRangeFn getIntRange, int32_t indexBitwidth) {
+ SmallVector<IntegerValueRange> ranges;
+ ranges.reserve(values.size());
+ for (OpFoldResult ofr : values) {
+ if (auto value = dyn_cast<Value>(ofr)) {
+ ranges.push_back(getIntRange(value));
+ continue;
+ }
+
+ // Create a constant range.
+ auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
+ ranges.emplace_back(ConstantIntRanges::constant(
+ attr.getValue().sextOrTrunc(indexBitwidth)));
+ }
+ return ranges;
+}
+
void mlir::intrange::detail::defaultInferResultRanges(
InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
SetIntLatticeFn setResultRanges) {
diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
new file mode 100644
index 0000000..483e9f1
--- /dev/null
+++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
@@ -0,0 +1,36 @@
+//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <optional>
+
+using namespace mlir;
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc"
+
+void StridedMetadataRange::print(raw_ostream &os) const {
+ if (isUninitialized()) {
+ os << "strided_metadata<None>";
+ return;
+ }
+ os << "strided_metadata<offset = [";
+ llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "], sizes = [";
+ llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "], strides = [";
+ llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "]>";
+}
diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp
index d213a1a..bf36286 100644
--- a/mlir/lib/Remark/RemarkStreamer.cpp
+++ b/mlir/lib/Remark/RemarkStreamer.cpp
@@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() {
namespace mlir::remark {
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
const RemarkCategories &cat, bool printAsEmitRemarks) {
FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr =
@@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer(
if (failed(sOr))
return failure();
- return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat,
+ return remark::enableOptimizationRemarks(ctx, std::move(*sOr),
+ std::move(remarkEmittingPolicy), cat,
printAsEmitRemarks);
}
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index cb90ef8..d52d5e7 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
-void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<const Record *> opDefs) {
- NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
+void StaticVerifierFunctionEmitter::emitOpConstraints() {
emitTypeConstraints();
emitAttrConstraints();
emitPropConstraints();
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 4bbcd8e..db39c70 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) {
return UnknownLoc::get(context);
// Add a fused location to link the subprogram information.
- StringAttr funcName = StringAttr::get(context, subprogram->getName());
StringAttr fileName = StringAttr::get(context, subprogram->getFilename());
return FusedLocWith<DISubprogramAttr>::get(
- {NameLoc::get(funcName),
- FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
+ {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
translate(subprogram), context);
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 9603813..857e31b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{
StringLiteral("denormal-fp-math-f32"),
StringLiteral("fp-contract"),
StringLiteral("frame-pointer"),
+ StringLiteral("inlinehint"),
StringLiteral("instrument-function-entry"),
StringLiteral("instrument-function-exit"),
StringLiteral("memory"),
@@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
funcOp.setNoInline(true);
if (func->hasFnAttribute(llvm::Attribute::AlwaysInline))
funcOp.setAlwaysInline(true);
+ if (func->hasFnAttribute(llvm::Attribute::InlineHint))
+ funcOp.setInlineHint(true);
if (func->hasFnAttribute(llvm::Attribute::OptimizeNone))
funcOp.setOptimizeNone(true);
if (func->hasFnAttribute(llvm::Attribute::Convergent))
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 845a14f..147613f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1652,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->addFnAttr(llvm::Attribute::NoInline);
if (func.getAlwaysInlineAttr())
llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline);
+ if (func.getInlineHintAttr())
+ llvmFunc->addFnAttr(llvm::Attribute::InlineHint);
if (func.getOptimizeNoneAttr())
llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone);
if (func.getConvergentAttr())
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 0c3e87a..d9ad8fb 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2619,6 +2619,11 @@ LogicalResult ControlFlowStructurizer::structurize() {
// region. We cannot handle such cases given that once a value is sinked into
// the SelectionOp/LoopOp's region, there is no escape for it.
for (auto *block : constructBlocks) {
+ if (!block->use_empty())
+ return emitError(block->getParent()->getLoc(),
+ "failed control flow structurization: "
+ "block has uses outside of the "
+ "enclosing selection/loop construct");
for (Operation &op : *block)
if (!op.use_empty())
return op.emitOpError("failed control flow structurization: value has "
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 132be4e..366ba8f 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
@@ -138,6 +139,10 @@ using ImportDesc =
using parsed_inst_t = FailureOr<SmallVector<Value>>;
+struct EmptyBlockMarker {};
+using BlockTypeParseResult =
+ std::variant<EmptyBlockMarker, TypeIdxRecord, Type>;
+
struct WasmModuleSymbolTables {
SmallVector<FunctionSymbolRefContainer> funcSymbols;
SmallVector<GlobalSymbolRefContainer> globalSymbols;
@@ -175,6 +180,9 @@ class ParserHead;
/// Wrapper around SmallVector to only allow access as push and pop on the
/// stack. Makes sure that there are no "free accesses" on the stack to preserve
/// its state.
+/// This class also keeps track of the Wasm labels defined by different ops,
+/// which can be targeted by control flow ops. This can be modeled as part of
+/// the Value Stack as Wasm control flow ops can only target enclosing labels.
class ValueStack {
private:
struct LabelLevel {
@@ -206,6 +214,16 @@ public:
/// if an error occurs.
LogicalResult pushResults(ValueRange results, Location *opLoc);
+ void addLabelLevel(LabelLevelOpInterface levelOp) {
+ labelLevel.push_back({values.size(), levelOp});
+ LDBG() << "Adding a new frame context to ValueStack";
+ }
+
+ void dropLabelLevel() {
+ assert(!labelLevel.empty() && "Trying to drop a frame from empty context");
+ auto newSize = labelLevel.pop_back_val().stackIdx;
+ values.truncate(newSize);
+ }
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// A simple dump function for debugging.
/// Writes output to llvm::dbgs().
@@ -214,6 +232,7 @@ public:
private:
SmallVector<Value> values;
+ SmallVector<LabelLevel> labelLevel;
};
using local_val_t = TypedValue<wasmssa::LocalRefType>;
@@ -248,6 +267,19 @@ private:
buildNumericOp(OpBuilder &builder,
std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
+ /// Construct a conversion operation of type \p opType that takes a value of
+ /// type \p inputType from the stack and produces a value of type
+ /// \p outputType.
+ ///
+ /// \p opType - The WASM dialect operation to build.
+ /// \p inputType - The operand type for the built instruction.
+ /// \p outputType - The result type for the built instruction.
+ ///
+ /// \returns The parsed instruction result, or failure.
+ template <typename opType, typename inputType, typename outputType,
+ typename... extraArgsT>
+ inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...);
+
/// This function generates a dispatch tree to associate an opcode with a
/// parser. Parsers are registered by specialising the
/// `parseSpecificInstruction` function for the op code to handle.
@@ -280,11 +312,105 @@ private:
}
}
+ ///
+ /// RAII guard class for creating a nesting level
+ ///
+ struct NestingContextGuard {
+ NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp)
+ : parser{parser} {
+ parser.addNestingContextLevel(levelOp);
+ }
+ NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} {
+ other.shouldDropOnDestruct = false;
+ }
+ NestingContextGuard(NestingContextGuard const &) = delete;
+ ~NestingContextGuard() {
+ if (shouldDropOnDestruct)
+ parser.dropNestingContextLevel();
+ }
+ ExpressionParser &parser;
+ bool shouldDropOnDestruct = true;
+ };
+
+ void addNestingContextLevel(LabelLevelOpInterface levelOp) {
+ valueStack.addLabelLevel(levelOp);
+ }
+
+ void dropNestingContextLevel() {
+ // Should always succeed as we are dropping the frame that was previously
+ // created.
+ valueStack.dropLabelLevel();
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ EmptyBlockMarker) {
+ return builder.getFunctionType({}, {});
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ TypeIdxRecord type) {
+ if (type.id >= symbols.moduleFuncTypes.size())
+ return emitError(*currentOpLoc,
+ "type index references nonexistent type (")
+ << type.id << "). Only " << symbols.moduleFuncTypes.size()
+ << " types are registered";
+ return symbols.moduleFuncTypes[type.id];
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ Type valType) {
+ return builder.getFunctionType({}, {valType});
+ }
+
+ llvm::FailureOr<FunctionType>
+ getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) {
+ return std::visit(
+ [this, &builder](auto value) { return getFuncTypeFor(builder, value); },
+ parseResult);
+ }
+
+ llvm::FailureOr<FunctionType>
+ getFuncTypeFor(OpBuilder &builder,
+ llvm::FailureOr<BlockTypeParseResult> parseResult) {
+ if (llvm::failed(parseResult))
+ return failure();
+ return getFuncTypeFor(builder, *parseResult);
+ }
+
+ llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder);
+
struct ParseResultWithInfo {
SmallVector<Value> opResults;
std::byte endingByte;
};
+ template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>>
+ /// @param blockToFill: the block whose content will be populated
+ /// @param resTypes: the types that this block is supposed to return
+ llvm::FailureOr<std::byte>
+ parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes,
+ Location opLoc, LabelLevelOpInterface levelOp,
+ FilterT parseEndBytes = {}) {
+ OpBuilder::InsertionGuard guard{builder};
+ builder.setInsertionPointToStart(blockToFill);
+ LDBG() << "parsing a block of type "
+ << builder.getFunctionType(blockToFill->getArgumentTypes(),
+ resTypes);
+ auto nC = addNesting(levelOp);
+
+ if (failed(pushResults(blockToFill->getArguments())))
+ return failure();
+ auto bodyParsingRes = parse(builder, parseEndBytes);
+ if (failed(bodyParsingRes))
+ return failure();
+ auto returnOperands = popOperands(resTypes);
+ if (failed(returnOperands))
+ return failure();
+ builder.create<BlockReturnOp>(opLoc, *returnOperands);
+ LDBG() << "end of parsing of a block";
+ return bodyParsingRes->endingByte;
+ }
+
public:
template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
@@ -294,7 +420,11 @@ public:
parse(OpBuilder &builder,
ByteSequence<ExpressionParseEnd...> parsingEndFilters);
- FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
+ NestingContextGuard addNesting(LabelLevelOpInterface levelOp) {
+ return NestingContextGuard{*this, levelOp};
+ }
+
+ FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) {
return valueStack.popOperands(operandTypes, &currentOpLoc.value());
}
@@ -308,6 +438,12 @@ public:
template <typename OpToCreate>
parsed_inst_t parseSetOrTee(OpBuilder &);
+ /// Blocks and Loops have a similar format and differ only in how their exit
+ /// is handled, which doesn't matter at parsing time. Both are parsed by one
+ /// function.
+ template <typename OpToCreate>
+ parsed_inst_t parseBlockLikeOp(OpBuilder &);
+
private:
std::optional<Location> currentOpLoc;
ParserHead &parser;
@@ -586,6 +722,29 @@ public:
return success();
}
+ llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) {
+ auto loc = getLocation();
+ auto blockIndicator = peek();
+ if (failed(blockIndicator))
+ return failure();
+ if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) {
+ offset += 1;
+ return {EmptyBlockMarker{}};
+ }
+ if (isValueOneOf(*blockIndicator, valueTypesEncodings))
+ return parseValueType(ctx);
+ /// Block type idx is a 32 bit positive integer encoded as a 33 bit signed
+ /// value
+ auto typeIdx = parseI64();
+ if (failed(typeIdx))
+ return failure();
+ if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max())
+ return emitError(loc, "type ID should be representable with an unsigned "
+ "32 bits integer. Got ")
+ << *typeIdx;
+ return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}};
+ }
+
bool end() const { return curHead().empty(); }
ParserHead copy() const { return *this; }
@@ -701,17 +860,41 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
void ValueStack::dump() const {
llvm::dbgs() << "================= Wasm ValueStack =======================\n";
llvm::dbgs() << "size: " << size() << "\n";
+ llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n';
llvm::dbgs() << "<Top>"
<< "\n";
// Stack is pushed to via push_back. Therefore the top of the stack is the
// end of the vector. Iterate in reverse so that the first thing we print
// is the top of the stack.
+ auto indexGetter = [this]() {
+ size_t idx = labelLevel.size();
+ return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> {
+ llvm::dbgs() << "IDX: " << idx << '\n';
+ if (idx == 0)
+ return std::nullopt;
+ auto frameId = idx - 1;
+ auto frameLimit = labelLevel[frameId].stackIdx;
+ idx -= 1;
+ return {{frameId, frameLimit}};
+ };
+ };
+ auto getNextFrameIndex = indexGetter();
+ auto nextFrameIdx = getNextFrameIndex();
size_t stackSize = size();
- for (size_t idx = 0; idx < stackSize; idx++) {
+ for (size_t idx = 0; idx < stackSize; ++idx) {
size_t actualIdx = stackSize - 1 - idx;
+ while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) {
+ llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first
+ << ")\n";
+ nextFrameIdx = getNextFrameIndex();
+ }
llvm::dbgs() << " ";
values[actualIdx].dump();
}
+ while (nextFrameIdx) {
+ llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n";
+ nextFrameIdx = getNextFrameIndex();
+ }
llvm::dbgs() << "<Bottom>"
<< "\n";
llvm::dbgs() << "=========================================================\n";
@@ -726,7 +909,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
return emitError(*opLoc,
"stack doesn't contain enough values. trying to get ")
<< operandTypes.size() << " operands on a stack containing only "
- << values.size() << " values.";
+ << values.size() << " values";
size_t stackIdxOffset = values.size() - operandTypes.size();
SmallVector<Value> res{};
res.reserve(operandTypes.size());
@@ -735,8 +918,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
Type stackType = operand.getType();
if (stackType != operandTypes[i])
return emitError(*opLoc, "invalid operand type on stack. expecting ")
- << operandTypes[i] << ", value on stack is of type " << stackType
- << ".";
+ << operandTypes[i] << ", value on stack is of type " << stackType;
LDBG() << " POP: " << operand;
res.push_back(operand);
}
@@ -792,6 +974,151 @@ ExpressionParser::parse(OpBuilder &builder,
}
}
+llvm::FailureOr<FunctionType>
+ExpressionParser::parseBlockFuncType(OpBuilder &builder) {
+ return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext()));
+}
+
+template <typename OpToCreate>
+parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
+ auto opLoc = currentOpLoc;
+ auto funcType = parseBlockFuncType(builder);
+ if (failed(funcType))
+ return failure();
+
+ auto inputTypes = funcType->getInputs();
+ auto inputOps = popOperands(inputTypes);
+ if (failed(inputOps))
+ return failure();
+
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto resTypes = funcType->getResults();
+ llvm::SmallVector<Location> locations{};
+ locations.resize(resTypes.size(), *currentOpLoc);
+ auto *successor =
+ builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
+ builder.setInsertionPointToEnd(curBlock);
+ auto blockOp =
+ builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor);
+ auto *blockBody = blockOp.createBlock();
+ if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
+ return failure();
+ builder.setInsertionPointToStart(successor);
+ return {ValueRange{successor->getArguments()}};
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>(
+ OpBuilder &builder) {
+ return parseBlockLikeOp<BlockOp>(builder);
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>(
+ OpBuilder &builder) {
+ return parseBlockLikeOp<LoopOp>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) {
+ auto opLoc = currentOpLoc;
+ auto funcType = parseBlockFuncType(builder);
+ if (failed(funcType))
+ return failure();
+
+ LDBG() << "Parsing an if instruction of type " << *funcType;
+ auto inputTypes = funcType->getInputs();
+ auto conditionValue = popOperands(builder.getI32Type());
+ if (failed(conditionValue))
+ return failure();
+ auto inputOps = popOperands(inputTypes);
+ if (failed(inputOps))
+ return failure();
+
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto resTypes = funcType->getResults();
+ llvm::SmallVector<Location> locations{};
+ locations.resize(resTypes.size(), *currentOpLoc);
+ auto *successor =
+ builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
+ builder.setInsertionPointToEnd(curBlock);
+ auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(),
+ *inputOps, successor);
+ auto *ifEntryBlock = ifOp.createIfBlock();
+ constexpr auto ifElseFilter =
+ ByteSequence<WasmBinaryEncoding::endByte,
+ WasmBinaryEncoding::OpCode::elseOpCode>{};
+ auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc,
+ ifOp, ifElseFilter);
+ if (failed(parseIfRes))
+ return failure();
+ if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) {
+ LDBG() << " else block is present.";
+ Block *elseEntryBlock = ifOp.createElseBlock();
+ auto parseElseRes =
+ parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp);
+ if (failed(parseElseRes))
+ return failure();
+ }
+ builder.setInsertionPointToStart(successor);
+ return {ValueRange{successor->getArguments()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) {
+ auto level = parser.parseLiteral<uint32_t>();
+ if (failed(level))
+ return failure();
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto sip = builder.saveInsertionPoint();
+ Block *elseBlock = builder.createBlock(curRegion, curRegion->end());
+ auto condition = popOperands(builder.getI32Type());
+ if (failed(condition))
+ return failure();
+ builder.restoreInsertionPoint(sip);
+ auto targetOp =
+ LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level);
+ if (failed(targetOp))
+ return failure();
+ auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes();
+ auto branchArgs = popOperands(inputTypes);
+ if (failed(branchArgs))
+ return failure();
+ builder.create<BranchIfOp>(*currentOpLoc, condition->front(),
+ builder.getUI32IntegerAttr(*level), *branchArgs,
+ elseBlock);
+ builder.setInsertionPointToStart(elseBlock);
+ return {*branchArgs};
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
+ OpBuilder &builder) {
+ auto loc = *currentOpLoc;
+ auto funcIdx = parser.parseLiteral<uint32_t>();
+ if (failed(funcIdx))
+ return failure();
+ if (*funcIdx >= symbols.funcSymbols.size())
+ return emitError(loc, "Invalid function index: ") << *funcIdx;
+ auto callee = symbols.funcSymbols[*funcIdx];
+ llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs();
+ llvm::ArrayRef<Type> resTypes = callee.functionType.getResults();
+ parsed_inst_t inOperands = popOperands(inTypes);
+ if (failed(inOperands))
+ return failure();
+ auto callOp =
+ builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands);
+ return {callOp.getResults()};
+}
+
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
@@ -834,7 +1161,7 @@ parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
if (valueStack.empty())
return emitError(
*currentOpLoc,
- "invalid stack access, trying to access a value on an empty stack.");
+ "invalid stack access, trying to access a value on an empty stack");
parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
if (failed(poppedOp))
@@ -956,7 +1283,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
<< ", type = " << ty << " ***";
auto tysToPop = SmallVector<Type, numOperands>();
tysToPop.resize(numOperands);
- std::fill(tysToPop.begin(), tysToPop.end(), ty);
+ llvm::fill(tysToPop, ty);
auto operands = popOperands(tysToPop);
if (failed(operands))
return failure();
@@ -1000,11 +1327,23 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
BUILD_NUMERIC_BINOP_FP(DivOp, div)
+BUILD_NUMERIC_BINOP_FP(GeOp, ge)
+BUILD_NUMERIC_BINOP_FP(GtOp, gt)
+BUILD_NUMERIC_BINOP_FP(LeOp, le)
+BUILD_NUMERIC_BINOP_FP(LtOp, lt)
BUILD_NUMERIC_BINOP_FP(MaxOp, max)
BUILD_NUMERIC_BINOP_FP(MinOp, min)
BUILD_NUMERIC_BINOP_INT(AndOp, and)
BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
+BUILD_NUMERIC_BINOP_INT(GeSIOp, geS)
+BUILD_NUMERIC_BINOP_INT(GeUIOp, geU)
+BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS)
+BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU)
+BUILD_NUMERIC_BINOP_INT(LeSIOp, leS)
+BUILD_NUMERIC_BINOP_INT(LeUIOp, leU)
+BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS)
+BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU)
BUILD_NUMERIC_BINOP_INT(OrOp, or)
BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
@@ -1015,7 +1354,9 @@ BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
BUILD_NUMERIC_BINOP_INTFP(AddOp, add)
+BUILD_NUMERIC_BINOP_INTFP(EqOp, eq)
BUILD_NUMERIC_BINOP_INTFP(MulOp, mul)
+BUILD_NUMERIC_BINOP_INTFP(NeOp, ne)
BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil)
@@ -1025,6 +1366,7 @@ BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
+BUILD_NUMERIC_UNARY_OP_INT(EqzOp, eqz)
BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
// Don't need these anymore so let's undef them.
@@ -1036,6 +1378,105 @@ BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
#undef BUILD_NUMERIC_OP
#undef BUILD_NUMERIC_CAST_OP
+/// Builds the conversion operation `opType`.
+///
+/// One operand is popped using the MLIR type obtained from
+/// `buildLiteralType<inputType>`, and the op is created with the result type
+/// obtained from `buildLiteralType<outputType>`. Any `extraArgs` are forwarded
+/// unchanged to the op builder, after the operand.
+///
+/// Returns failure if `popOperands` fails; otherwise returns the single result
+/// of the newly created op.
+template <typename opType, typename inputType, typename outputType,
+          typename... extraArgsT>
+inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
+                                                      extraArgsT... extraArgs) {
+  static_assert(std::is_arithmetic_v<inputType>,
+                "InputType should be an arithmetic type");
+  static_assert(std::is_arithmetic_v<outputType>,
+                "OutputType should be an arithmetic type");
+  auto inType = buildLiteralType<inputType>(builder);
+  auto outType = buildLiteralType<outputType>(builder);
+  auto operand = popOperands(inType);
+  if (failed(operand))
+    return failure();
+  // Use the op's static `create` entry point, matching the other op
+  // constructions in this file (e.g. `FuncImportOp::create`).
+  auto op = opType::create(builder, *currentOpLoc, outType, operand->front(),
+                           extraArgs...);
+  LDBG() << "Built operation: " << op;
+  return {{op.getResult()}};
+}
+
+/// Parser for the `demoteF64ToF32` opcode: builds a `DemoteOp` through
+/// `buildConvertOp`, with `double` as the input literal type and `float` as
+/// the output literal type.
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) {
+ return buildConvertOp<DemoteOp, double, float>(builder);
+}
+
+/// Parser for the `wrap` opcode: builds a `WrapOp` through `buildConvertOp`,
+/// with `int64_t` as the input literal type and `int32_t` as the output
+/// literal type.
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>(
+ OpBuilder &builder) {
+ return buildConvertOp<WrapOp, int64_t, int32_t>(builder);
+}
+
+// Defines the `parseSpecificInstruction` specialization for opcode SOURCE_OP.
+// The specialization forwards to `buildConvertOp<TARGET_OP, IN_T, OUT_T>`.
+#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \
+ return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \
+ }
+
+// Instantiates the four integer-to-DEST_T conversion parsers
+// (convert{UI,SI}{32,64}F<WIDTH>). Unsigned sources map to ConvertUOp and
+// signed sources map to ConvertSOp.
+#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \
+ BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \
+ BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \
+ BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \
+ BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp)
+
+BUILD_CONVERT_OP_FOR(float, 32)
+BUILD_CONVERT_OP_FOR(double, 64)
+
+#undef BUILD_CONVERT_OP_FOR
+
+// Parsers for the `extendS` / `extendU` opcodes, both using int32_t as input
+// literal type and int64_t as output literal type.
+BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op)
+BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op)
+
+#undef BUILD_CONVERSION_OP
+
+// Defines the `parseSpecificInstruction` specialization for the opcode
+// extendI<IT_WIDTH><EXTRACT_WIDTH>S. The specialization builds an
+// ExtendLowBitsSOp whose operand and result both use int<IT_WIDTH>_t as
+// literal type, and passes EXTRACT_WIDTH to the op as a ui32 integer
+// attribute.
+//
+// The specialization is declared `inline`, like every other
+// `parseSpecificInstruction` specialization in this file.
+#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \
+  template <> \
+  inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+      WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \
+      OpBuilder & builder) { \
+    using inout_t = int##IT_WIDTH##_t; \
+    auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \
+    return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \
+  }
+
+BUILD_SLICE_EXTEND_PARSER(32, 8)
+BUILD_SLICE_EXTEND_PARSER(32, 16)
+BUILD_SLICE_EXTEND_PARSER(64, 8)
+BUILD_SLICE_EXTEND_PARSER(64, 16)
+BUILD_SLICE_EXTEND_PARSER(64, 32)
+
+#undef BUILD_SLICE_EXTEND_PARSER
+
+/// Parser for the `promoteF32ToF64` opcode: builds a `PromoteOp` through
+/// `buildConvertOp`, with `float` as the input literal type and `double` as
+/// the output literal type.
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::promoteF32ToF64>(OpBuilder &builder) {
+ return buildConvertOp<PromoteOp, float, double>(builder);
+}
+
+// Defines both reinterpret parsers for a given bit WIDTH:
+//  - reinterpretF<WIDTH>AsI<WIDTH>: FP_TYPE input, int<WIDTH>_t output;
+//  - reinterpretI<WIDTH>AsF<WIDTH>: int<WIDTH>_t input, FP_TYPE output.
+// Both build a ReinterpretOp through `buildConvertOp`.
+#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \
+ builder) { \
+ return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \
+ } \
+ \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \
+ builder) { \
+ return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \
+ }
+
+BUILD_REINTERPRET_PARSER(32, float)
+BUILD_REINTERPRET_PARSER(64, double)
+
+#undef BUILD_REINTERPRET_PARSER
+
class WasmBinaryParser {
private:
struct SectionRegistry {
@@ -1153,7 +1594,7 @@ private:
if (tid.id >= symbols.moduleFuncTypes.size())
return emitError(loc, "invalid type id: ")
<< tid.id << ". Only " << symbols.moduleFuncTypes.size()
- << " type registration.";
+ << " type registrations";
FunctionType type = symbols.moduleFuncTypes[tid.id];
std::string symbol = symbols.getNewFuncSymbolName();
auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
@@ -1221,7 +1662,7 @@ public:
FileLineColLoc magicLoc = parser.getLocation();
FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
if (failed(magic) || magic->compare(wasmHeader)) {
- emitError(magicLoc, "source file does not contain valid Wasm header.");
+ emitError(magicLoc, "source file does not contain valid Wasm header");
return;
}
auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
@@ -1391,7 +1832,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
return failure();
Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
- SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+ op->setAttr("exported", UnitAttr::get(op->getContext()));
StringAttr symName = SymbolTable::getSymbolName(op);
return SymbolTable{mOp}.rename(symName, *exportName);
}
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index 9670285..3fda5a7 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -93,7 +93,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
// Emit function to add the generated matchers to the pattern list.
os << "template <typename... ConfigsT>\n"
- "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
+ "[[maybe_unused]] static void populateGeneratedPDLLPatterns("
"::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
for (const auto &name : patternNames)
os << " patterns.add<" << name
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index c883baa..3236b4f 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <optional>
@@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
llvm::SourceMgr tdSrcMgr;
tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
+ tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
// This class provides a context argument for the llvm::SourceMgr diagnostic
// handler.
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 30fd384..9ef405d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -37,6 +37,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Remarks/RemarkFormat.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/ManagedStatic.h"
@@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
"bitstream", "Print bitstream file")),
llvm::cl::cat(remarkCategory)};
+    // `--remark-policy=<policy>`: selects which remarks are emitted. The parsed
+    // value is written to `remarkPolicyFlag`; the default is "all".
+    static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{
+        "remark-policy",
+        llvm::cl::desc("Specify the policy for remark output."),
+        cl::location(remarkPolicyFlag),
+        // The value is a policy name, not a format.
+        llvm::cl::value_desc("policy"),
+        llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL),
+        llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all",
+                                    "Print all remarks"),
+                         clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final",
+                                    "Print final remarks")),
+        llvm::cl::cat(remarkCategory)};
+
static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll(
"remarks-filter",
cl::desc("Show all remarks: passed, missed, failed, analysis"),
@@ -517,18 +530,28 @@ performActions(raw_ostream &os,
return failure();
context->enableMultithreading(wasThreadingEnabled);
-
+ // Set the remark categories and policy.
remark::RemarkCategories cats{
config.getRemarksAllFilter(), config.getRemarksPassedFilter(),
config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(),
config.getRemarksFailedFilter()};
mlir::MLIRContext &ctx = *context;
+  // Maps the remark policy selected in `config` to a freshly allocated
+  // emitting-policy object. Each call returns a new instance.
+  auto createPolicy = [&config]()
+      -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> {
+    const auto selected = config.getRemarkPolicy();
+    if (selected == RemarkPolicy::REMARK_POLICY_ALL)
+      return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>();
+    if (selected == RemarkPolicy::REMARK_POLICY_FINAL)
+      return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>();
+
+    llvm_unreachable("Invalid remark policy");
+  };
switch (config.getRemarkFormat()) {
case RemarkFormat::REMARK_FORMAT_STDOUT:
if (failed(mlir::remark::enableOptimizationRemarks(
- ctx, nullptr, cats, true /*printAsEmitRemarks*/)))
+ ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/)))
return failure();
break;
@@ -537,7 +560,7 @@ performActions(raw_ostream &os,
? "mlir-remarks.yaml"
: config.getRemarksOutputFile();
if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- ctx, file, llvm::remarks::Format::YAML, cats)))
+ ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats)))
return failure();
break;
}
@@ -547,7 +570,7 @@ performActions(raw_ostream &os,
? "mlir-remarks.bitstream"
: config.getRemarksOutputFile();
if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- ctx, file, llvm::remarks::Format::Bitstream, cats)))
+ ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats)))
return failure();
break;
}
@@ -593,6 +616,12 @@ performActions(raw_ostream &os,
AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
&fallbackResourceMap);
os << OpWithState(op.get(), asmState) << '\n';
+
+ // This is required if the remark policy is final. Otherwise, the remarks are
+ // not emitted.
+ if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine())
+ engine->getRemarkEmittingPolicy()->finalize();
+
return success();
}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 60b9567..1dbe7eca 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -31,6 +31,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include <optional>
using namespace mlir;
@@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
llvm::append_range(includeDirs, extraDirs);
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 3080b78..2d817be 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -17,6 +17,7 @@
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Parser.h"
#include "llvm/TableGen/Record.h"
#include <optional>
@@ -448,6 +449,7 @@ void TableGenTextFile::initialize(
return;
}
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
// This class provides a context argument for the SourceMgr diagnostic
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 111f58e..5f3b04a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode(
size_t numMoved = 0;
for (Region *region : regions) {
- LDBG() << "Original loop:\n" << *region->getParentOp();
+ LDBG() << "Original loop:\n"
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
@@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode(
!canBeHoisted(op, definedOutside))
continue;
- LDBG() << "Moving loop-invariant op: " << *op;
+ LDBG() << "Moving loop-invariant op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
moveOutOfRegion(op, region);
++numMoved;
@@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
[&](Value value, Region *) {
return loopLike.isDefinedOutsideOfLoop(value);
},
- [&](Operation *op, Region *) {
- return isMemoryEffectFree(op) && isSpeculatable(op);
- },
+ [&](Operation *op, Region *) { return isPure(op); },
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9f5246d..ffa96ad 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -137,6 +137,16 @@ declare_mlir_dialect_python_bindings(
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/OpenACCOps.td
+ SOURCES
+ dialects/openacc.py
+ DIALECT_NAME acc
+ DEPENDS acc_common_td
+ )
+
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUOps.td
SOURCES_GLOB dialects/gpu/*.py
DIALECT_NAME gpu
diff --git a/mlir/python/mlir/dialects/OpenACCOps.td b/mlir/python/mlir/dialects/OpenACCOps.td
new file mode 100644
index 0000000..69a3002
--- /dev/null
+++ b/mlir/python/mlir/dialects/OpenACCOps.td
@@ -0,0 +1,14 @@
+//===-- OpenACCOps.td - Entry point for OpenACCOps bind ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_OPENACC_OPS
+#define PYTHON_BINDINGS_OPENACC_OPS
+
+include "mlir/Dialect/OpenACC/OpenACCOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa..b14ea68 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -3,5 +3,151 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._gpu_ops_gen import *
+from .._gpu_ops_gen import _Dialect
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
+from typing import Callable, Sequence, Union, Optional, List
+
+try:
+ from ...ir import (
+ FunctionType,
+ TypeAttr,
+ StringAttr,
+ UnitAttr,
+ Block,
+ InsertionPoint,
+ ArrayAttr,
+ Type,
+ DictAttr,
+ Attribute,
+ DenseI32ArrayAttr,
+ )
+ from .._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GPUFuncOp(GPUFuncOp):
+ __doc__ = GPUFuncOp.__doc__
+
+ KERNEL_ATTR_NAME = "gpu.kernel"
+ KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size"
+ KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size"
+
+ FUNCTION_TYPE_ATTR_NAME = "function_type"
+ SYM_NAME_ATTR_NAME = "sym_name"
+ ARGUMENT_ATTR_NAME = "arg_attrs"
+ RESULT_ATTR_NAME = "res_attrs"
+
+ def __init__(
+ self,
+ function_type: Union[FunctionType, TypeAttr],
+ sym_name: Optional[Union[str, StringAttr]] = None,
+ kernel: Optional[bool] = None,
+ workgroup_attrib_attrs: Optional[Sequence[dict]] = None,
+ private_attrib_attrs: Optional[Sequence[dict]] = None,
+ known_block_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
+ known_grid_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
+ loc=None,
+ ip=None,
+ body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
+ ):
+ """
+ Create a GPUFuncOp with the provided `function_type`, `sym_name`,
+ `kernel`, `workgroup_attrib_attrs`, `private_attrib_attrs`, `known_block_size`,
+ `known_grid_size`, and `body_builder`.
+ - `function_type` is a FunctionType or a TypeAttr.
+ - `sym_name` is a string or a StringAttr representing the function name.
+ - `kernel` is a boolean representing whether the function is a kernel.
+ - `workgroup_attrib_attrs` is an optional list of dictionaries.
+ - `private_attrib_attrs` is an optional list of dictionaries.
+ - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size.
+ - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size.
+ - `body_builder` is an optional callback. When provided, a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ function_type = (
+ TypeAttr.get(function_type)
+ if not isinstance(function_type, TypeAttr)
+ else function_type
+ )
+ super().__init__(
+ function_type,
+ workgroup_attrib_attrs=workgroup_attrib_attrs,
+ private_attrib_attrs=private_attrib_attrs,
+ loc=loc,
+ ip=ip,
+ )
+
+ if isinstance(sym_name, str):
+ self.attributes[self.SYM_NAME_ATTR_NAME] = StringAttr.get(sym_name)
+ elif isinstance(sym_name, StringAttr):
+ self.attributes[self.SYM_NAME_ATTR_NAME] = sym_name
+ else:
+ raise ValueError("sym_name must be a string or a StringAttr")
+
+ if kernel:
+ self.attributes[self.KERNEL_ATTR_NAME] = UnitAttr.get()
+
+ if known_block_size is not None:
+ if isinstance(known_block_size, Sequence):
+ block_size = DenseI32ArrayAttr.get(known_block_size)
+ self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = block_size
+ elif isinstance(known_block_size, DenseI32ArrayAttr):
+ self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = known_block_size
+ else:
+ raise ValueError(
+ "known_block_size must be a list of integers or a DenseI32ArrayAttr"
+ )
+
+ if known_grid_size is not None:
+ if isinstance(known_grid_size, Sequence):
+ grid_size = DenseI32ArrayAttr.get(known_grid_size)
+ self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = grid_size
+ elif isinstance(known_grid_size, DenseI32ArrayAttr):
+ self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = known_grid_size
+ else:
+ raise ValueError(
+ "known_grid_size must be a list of integers or a DenseI32ArrayAttr"
+ )
+
+ if body_builder is not None:
+ with InsertionPoint(self.add_entry_block()):
+ body_builder(self)
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes[self.SYM_NAME_ATTR_NAME])
+
+ @property
+ def is_kernel(self) -> bool:
+ return self.KERNEL_ATTR_NAME in self.attributes
+
+ def add_entry_block(self) -> Block:
+ if len(self.body.blocks) > 0:
+ raise RuntimeError(f"Entry block already exists for {self.name.value}")
+
+ function_type = self.function_type.value
+ return self.body.blocks.append(
+ *function_type.inputs,
+ arg_locs=[self.location for _ in function_type.inputs],
+ )
+
+ @property
+ def entry_block(self) -> Block:
+ if len(self.body.blocks) == 0:
+ raise RuntimeError(
+ f"Entry block does not exist for {self.name.value}."
+ + " Do you need to call the add_entry_block() method on this GPUFuncOp?"
+ )
+ return self.body.blocks[0]
+
+ @property
+ def arguments(self) -> Sequence[Type]:
+ return self.function_type.value.inputs
diff --git a/mlir/python/mlir/dialects/openacc.py b/mlir/python/mlir/dialects/openacc.py
new file mode 100644
index 0000000..057f71a
--- /dev/null
+++ b/mlir/python/mlir/dialects/openacc.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._acc_ops_gen import *
diff --git a/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir
new file mode 100644
index 0000000..808c1c2
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-strided-metadata-range-analysis %s 2>&1 | FileCheck %s
+
+func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>, %arg1: memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>>, %arg2: memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>>, %arg3: index, %arg4: index, %arg5: index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index
+ %1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index
+
+ // Test subview with unknown sizes, and constant offsets and strides.
+ // CHECK: Op: %[[SV0:.*]] = memref.subview
+ // CHECK-NEXT: result[0]: strided_metadata<
+ // CHECK-SAME: offset = [{unsigned : [1, 1] signed : [1, 1]}]
+ // CHECK-SAME: sizes = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+ // CHECK-SAME: strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}]
+ %subview = memref.subview %arg0[%c0, %c0, %c1] [%arg3, %arg4, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+ // Test a subview of a subview, with bounded dynamic offsets.
+ // CHECK: Op: %[[SV1:.*]] = memref.subview
+ // CHECK-NEXT: result[0]: strided_metadata<
+ // CHECK-SAME: offset = [{unsigned : [346, 484] signed : [346, 484]}]
+ // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}]
+ // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}]
+ %subview_0 = memref.subview %subview[%1, %1, %1] [%c2, %c2, %c2] [%0, %0, %0] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+ // Test a subview of a subview, with constant operands.
+ // CHECK: Op: %[[SV2:.*]] = memref.subview
+ // CHECK-NEXT: result[0]: strided_metadata<
+ // CHECK-SAME: offset = [{unsigned : [368, 510] signed : [368, 510]}]
+ // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}]
+ // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}]
+ %subview_1 = memref.subview %subview_0[%c0, %c0, %c2] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+ // Test a rank-reducing subview.
+ // CHECK: Op: %[[SV3:.*]] = memref.subview
+ // CHECK-NEXT: result[0]: strided_metadata<
+ // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+ // CHECK-SAME: sizes = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [16, 16] signed : [16, 16]}]
+ // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+ %subview_2 = memref.subview %arg1[%arg4, %arg4, %arg4, %arg4, %arg4] [1, 64, 1, 16, 1] [%arg5, %arg5, %arg5, %arg5, %arg5] : memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>> to memref<64x16xf32, strided<[?, ?], offset: ?>>
+
+ // Test a subview of a rank-reducing subview.
+ // CHECK: Op: %[[SV4:.*]] = memref.subview
+ // CHECK-NEXT: result[0]: strided_metadata<
+ // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+ // CHECK-SAME: sizes = [{unsigned : [5, 7] signed : [5, 7]}]
+ // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+ %subview_3 = memref.subview %subview_2[%c0, %0] [1, %1] [%c1, %c2] : memref<64x16xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+
+ // Test a subview with mixed bounded and unbounded dynamic sizes.
+ // CHECK: Op: %[[SV5:.*]] = memref.subview
+ // CHECK-NEXT: result[0]: strided_metadata<
+ // CHECK-SAME: offset = [{unsigned : [32, 32] signed : [32, 32]}]
+ // CHECK-SAME: sizes = [{unsigned : [11, 13] signed : [11, 13]}, {unsigned : [5, 7] signed : [5, 7]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}]
+ // CHECK-SAME: strides = [{unsigned : [1, 1] signed : [1, 1]}, {unsigned : [64, 64] signed : [64, 64]}, {unsigned : [8, 8] signed : [8, 8]}]
+ %subview_4 = memref.subview %arg2[%c0, %c0, %c2] [%0, %1, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ return
+}
+
+// CHECK: func.func @memref_subview
+// CHECK: %[[A0:.*]]: memref<8x16x4xf32, strided<[64, 4, 1]>>
+// CHECK: %[[SV0]] = memref.subview %[[A0]]
+// CHECK-NEXT: %[[SV1]] = memref.subview
+// CHECK-NEXT: %[[SV2]] = memref.subview
+// CHECK-NEXT: %[[SV3]] = memref.subview
+// CHECK-NEXT: %[[SV4]] = memref.subview
+// CHECK-NEXT: %[[SV5]] = memref.subview
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index dbff233..455f886 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
+// The CHECK prefix is kept enabled in both runs so that the chipset-independent
+// CHECK / CHECK-LABEL lines of this file are still verified.
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefixes=CHECK,PRE9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefixes=CHECK,POST9
module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
@@ -596,3 +597,76 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}
+
+// -----
+
+// f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f16
+func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
+ %r = math.clampf %x to [%lo, %hi] : f16
+ return %r : f16
+ // POST9: rocdl.fmed3 {{.*}} : f16
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : f16
+}
+
+// f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f32
+func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
+ %r = math.clampf %x to [%lo, %hi] : f32
+ return %r : f32
+ // POST9: rocdl.fmed3 {{.*}} : f32
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : f32
+}
+
+// -----
+
+// Vector f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f16
+func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2xf16>
+ return %r : vector<2xf16>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2xf16>
+}
+
+// -----
+
+// Vector f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f32
+func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2xf32>
+ return %r : vector<2xf32>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2xf32>
+}
+
+// -----
+
+// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
+// CHECK-LABEL: func.func @clampf_vector_2d_f16
+func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
+ return %r : vector<2x2xf16>
+ // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2x2xf16>
+}
+
+// -----
+// CHECK-LABEL: func.func @clampf_bf16
+func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
+ %r = math.clampf %x to [%lo, %hi] : bf16
+ return %r : bf16
+ // CHECK: math.clampf {{.*}} : bf16
+ // CHECK-NOT: rocdl.fmed3
+}
diff --git a/mlir/test/Conversion/MathToXeVM/lit.local.cfg b/mlir/test/Conversion/MathToXeVM/lit.local.cfg
new file mode 100644
index 0000000..cc1ce35
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/lit.local.cfg
@@ -0,0 +1,7 @@
+spirv_backend_tests = [
+ 'native-spirv-builtins.mlir',
+]
+
+# Exclude SPIRV backend tests if SPIRV target is disabled:
+if(not config.run_xevm_tests):
+ config.excludes.update(spirv_backend_tests)
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000..d76627b
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,155 @@
+// RUN: mlir-opt %s -convert-math-to-xevm \
+// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH'
+// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
+// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
+
+module @test_module {
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ //
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+ //
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+ // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+ // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+
+ // CHECK-LABEL: func @math_ops
+ func.func @math_ops() {
+
+ %c1_f16 = arith.constant 1. : f16
+ %c1_f32 = arith.constant 1. : f32
+ %c1_f64 = arith.constant 1. : f64
+
+ // CHECK: math.exp
+ %exp_normal_f16 = math.exp %c1_f16 : f16
+ // CHECK: math.exp
+ %exp_normal_f32 = math.exp %c1_f32 : f32
+ // CHECK: math.exp
+ %exp_normal_f64 = math.exp %c1_f64 : f64
+
+ // Check float operations are converted properly:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+ %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+ %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+ %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
+
+ // CHECK: math.exp
+ %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+ // CHECK: math.exp
+ %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+ // CHECK: math.exp
+ %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+ // Check vector operations:
+
+ %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
+ %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
+ %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
+ %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
+ %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
+ %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
+ %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
+ %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+ %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
+ %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+
+ %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
+ %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32>
+ %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16>
+ %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+ // Check unsupported vector sizes are not converted:
+
+ %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
+ %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
+
+ // CHECK: math.exp
+ %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+ // CHECK: math.exp
+ %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+
+ // Check fastmath flags propagate properly:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+ %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32
+ %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32
+ %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
+
+ // Check all other math operations:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32
+
+ // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16
+ %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
+
+ %c6_9_f32 = arith.constant 6.9 : f32
+ %c7_f32 = arith.constant 7. : f32
+
+ // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+ // CHECK-NO-ARITH: arith.divf
+ %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32
+
+ return
+ }
+}
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
new file mode 100644
index 0000000..82426c4
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
+// RUN: -debug-only=serialize-to-isa 2> %t
+// RUN: FileCheck --input-file=%t %s
+// REQUIRES: asserts
+//
+// MathToXeVM pass generates OpenCL intrinsic function calls when converting
+// Math ops with a `fastmath` attr to native function calls. It is assumed that
+// the SPIRV backend correctly converts these intrinsic calls to OpenCL
+// ExtInst instructions in SPIRV (See llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
+//
+// To ensure this assumption holds, this test verifies that the SPIRV backend
+// behaves as expected.
+
+module @test_ocl_intrinsics attributes {gpu.container_module} {
+ gpu.module @kernel [#xevm.target] {
+ llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
+ // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
+ // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
+ %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+ // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32
+ // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]]
+ %c0_f32 = llvm.mlir.constant(0. : f32) : f32
+ // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
+ // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
+ %c0_f64 = llvm.mlir.constant(0. : f64) : f64
+
+ // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2
+ // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]]
+ %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64>
+ // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3
+ // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]]
+ %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32>
+ // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4
+ // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]]
+ %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64>
+ // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8
+ // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]]
+ %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64>
+ // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16
+ // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
+ %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>
+
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+ %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
+ %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
+ // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+ %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+
+ // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
+ %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
+ // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
+ %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
+ // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
+ %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
+ // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
+ %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
+ // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
+ %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
+
+ // SPIRV backend does not currently handle fastmath flags: The SPIRV
+ // backend would need to generate OpDecorate calls to decorate math ops
+ // with FPFastMathMode/FPFastMathModeINTEL decorations.
+ //
+ // FIXME: When support for fastmath flags in the SPIRV backend is added,
+ // add tests here to ensure fastmath flags are converted to the correct
+ // OpDecorate calls.
+ //
+ // See:
+ // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
+ // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
+
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]]
+ %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]]
+ %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]]
+ %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]]
+ %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]]
+ %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+ // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]]
+ %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+ // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]]
+ %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]]
+ %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]]
+ %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
+ %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]]
+ %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+
+ llvm.return
+ }
+
+ llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+ llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64>
+ llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32>
+ llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
+ llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
+ llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
+ llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+ llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+ llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+ llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+ llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64>
+ llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+ llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+ llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+ llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+ llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+ llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+ }
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2d33888..d669a3b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
// -----
+func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
+ return %0 : vector<1xf32>
+}
+// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32
+// CHECK-SAME: %[[A:.*]]: f32)
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK-NOT: llvm.shufflevector
+// CHECK: return %[[T0]] : vector<1xf32>
+
+// -----
+
func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
return %0 : vector<[2]xf32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index e6f22f0..a9ab0be 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,17 +1,13 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
-#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-
-gpu.module @load_store_check {
+gpu.module @test_kernel {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
// Loads are checked in a separate test.
// CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
// CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+ %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded
: vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
return %d : vector<8xf32>
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
new file mode 100644
index 0000000..d4cb493
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -0,0 +1,201 @@
+// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+
+ // e.g. for mem_desc<32x32xf32> (no explicit layout, i.e. row-major)
+ // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
+ //CHECK-LABEL: load_store_matrix_1
+ gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+
+ //CHECK: %[[TID:.*]] = gpu.thread_id x
+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+ %tid_x = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return %1: f32
+ }
+
+// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+ // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
+ //CHECK-LABEL: load_store_matrix_2
+ gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK: %[[c13:.*]] = arith.constant 13 : index
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
+
+
+ %tid_x = gpu.thread_id x
+ %c13 = arith.constant 13 : index
+ %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+ gpu.return %1: f16
+ }
+
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+ // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
+ //CHECK-LABEL: load_store_matrix_3
+ gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK: %[[c19:.*]] = arith.constant 19 : index
+ %tid_x = gpu.thread_id x
+ %c19 = arith.constant 19: index
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+ %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+ xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
+ //CHECK: gpu.return %[[loaded]] : f16
+ gpu.return %1: f16
+ }
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+ // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
+ //CHECK-LABEL: load_store_matrix_4
+ gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
+
+ %tid_x = gpu.thread_id x
+ %c16 = arith.constant 16 : index
+ %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
+ xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
+ gpu.return %1: vector<8xf16>
+ }
+
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+ // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
+ //CHECK-LABEL: load_store_matrix_5
+ gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[c48:.*]] = arith.constant 48 : index
+
+ %c16 = arith.constant 16 : index
+ %c48 = arith.constant 48 : index
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+ //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
+ //CHECK: %[[c2:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
+ //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+ //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+
+ %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+
+ //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
+ //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+
+ xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
+ gpu.return %1: vector<8xf16>
+ }
+
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 0b150e9..9c552d8 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
// CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
- // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
- // CHECK: scf.yield %[[VAR8]] : f16
- // CHECK: } else {
- // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
- // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
+ // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
// CHECK: scf.yield %[[VAR7]] : f16
+ // CHECK: } else {
+ // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
+ // CHECK: scf.yield %[[CST_0]] : f16
// CHECK: }
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
gpu.return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: @source_materialize_single_elem_vec
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
+gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
+ %1 = arith.constant dense<1>: vector<1xi1>
+ %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ // CHECK: %[[VAR_IF:.*]] = scf.if
+ // CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
+ %c0 = arith.constant 0 : index
+ vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
+ gpu.return
+}
+}
+
// -----
gpu.module @test {
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index e56079c..1169cd1 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
// -----
+// CHECK-LABEL: func @delin_apply_cancel_exact
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]]
+// CHECK-NOT: memref.store
+// CHECK: return
+func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index
+ %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+ %c:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0]
+ memref.store %t2, %arg1[%t2] : memref<?xindex>
+
+ %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1]
+ memref.store %t3, %arg1[%t3] : memref<?xindex>
+
+ %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0]
+ memref.store %t4, %arg1[%t4] : memref<?xindex>
+
+ %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0]
+ memref.store %t5, %arg1[%t5] : memref<?xindex>
+
+ %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0]
+ memref.store %t6, %arg1[%t5] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @delin_apply_cancel_exact_dim
+// CHECK: affine.for %[[arg1:.+]] = 0 to 256
+// CHECK: memref.store %[[arg1]]
+// CHECK: return
+func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) {
+ affine.for %arg1 = 0 to 256 {
+ %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index
+ %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1)
+ memref.store %i, %arg0[%i] : memref<?xindex>
+ }
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_const_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_var_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) {
+ %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)>
+// CHECK-LABEL: func @delin_apply_cancel_nested_exprs
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) {
+ %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) {
+ %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)>
+// CHECK-LABEL: func @delin_apply_dont_cancel_partial
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1]
+// CHECK: return
+func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
index c58b153..21b508e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
@@ -65,13 +65,13 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) {
// -----
-func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
+func.func private @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
func.return %A : tensor<?xf32>
}
-// CHECK-LABEL: func @return_arg
+// CHECK-LABEL: func private @return_arg
// CHECK-SAME: %[[A:.*]]: memref<?xf32
// CHECK-NOT: return %[[A]]
-// NO-DROP-LABEL: func @return_arg
+// NO-DROP-LABEL: func private @return_arg
// NO-DROP-SAME: %[[A:.*]]: memref<?xf32
// NO-DROP: return %[[A]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 6054a61..d5f834b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -171,9 +171,9 @@ func.func @func_without_tensor_args(%v : vector<10xf32>) -> () {
// Bufferization of a function that is reading and writing. %t0 is writable, so
// no copy should be inserted.
-// CHECK-LABEL: func @inner_func(
+// CHECK-LABEL: func private @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
-func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+func.func private @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
// CHECK-NOT: copy
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
@@ -186,9 +186,9 @@ func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
return %0, %1 : tensor<?xf32>, f32
}
-// CHECK-LABEL: func @call_func_with_non_tensor_return(
+// CHECK-LABEL: func private @call_func_with_non_tensor_return(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
-func.func @call_func_with_non_tensor_return(
+func.func private @call_func_with_non_tensor_return(
%t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
// CHECK-NOT: alloc
// CHECK-NOT: copy
@@ -203,9 +203,9 @@ func.func @call_func_with_non_tensor_return(
// Bufferization of a function that is reading and writing. %t0 is not writable,
// so a copy is needed.
-// CHECK-LABEL: func @inner_func(
+// CHECK-LABEL: func private @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
-func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+func.func private @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
// CHECK-NOT: copy
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
@@ -276,10 +276,10 @@ func.func @main(%t: tensor<?xf32> {bufferization.writable = false}) -> (f32) {
// This function does not read, just write. We need an alloc, but no copy.
-// CHECK-LABEL: func @does_not_read(
+// CHECK-LABEL: func private @does_not_read(
// CHECK-NOT: alloc
// CHECK-NOT: copy
-func.func @does_not_read(%t: tensor<?xf32>) -> tensor<?xf32> {
+func.func private @does_not_read(%t: tensor<?xf32>) -> tensor<?xf32> {
%f0 = arith.constant 0.0 : f32
%r = linalg.fill ins(%f0 : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32>
return %r : tensor<?xf32>
@@ -354,9 +354,9 @@ func.func @main() {
// A write inside an scf.execute_region. An equivalent tensor is yielded.
-// CHECK-LABEL: func @execute_region_test(
+// CHECK-LABEL: func private @execute_region_test(
// CHECK-SAME: %[[m1:.*]]: memref<?xf32
-func.func @execute_region_test(%t1 : tensor<?xf32>)
+func.func private @execute_region_test(%t1 : tensor<?xf32>)
-> (f32, tensor<?xf32>, f32)
{
%f1 = arith.constant 0.0 : f32
@@ -397,11 +397,11 @@ func.func @no_inline_execute_region_not_canonicalized() {
// CHECK: func private @some_external_func(memref<?xf32, strided<[?], offset: ?>>)
func.func private @some_external_func(tensor<?xf32>)
-// CHECK: func @scf_for_with_tensor_insert_slice(
+// CHECK: func private @scf_for_with_tensor_insert_slice(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
-func.func @scf_for_with_tensor_insert_slice(
+func.func private @scf_for_with_tensor_insert_slice(
%A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>,
%lb : index, %ub : index, %step : index)
-> (tensor<?xf32>, tensor<?xf32>)
@@ -456,11 +456,11 @@ func.func @bar(
// -----
-// CHECK: func @init_and_dot(
+// CHECK: func private @init_and_dot(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, strided<[?], offset: ?>>
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, strided<[?], offset: ?>>
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, strided<[], offset: ?>>
-func.func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+func.func private @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
%v0 = arith.constant 0.0 : f32
@@ -574,9 +574,9 @@ func.func @entry(%A : tensor<?xf32> {bufferization.buffer_layout = affine_map<(i
// No alloc or copy inside of the loop.
-// CHECK-LABEL: func @inner_func(
+// CHECK-LABEL: func private @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
-func.func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
+func.func private @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
// CHECK: memref.store %{{.*}}, %[[arg0]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
index e2ab876..b52612d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -24,10 +24,46 @@
// CHECK-NOT: copy
// CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
- // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+ // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32{{.*}}>
return %1, %0 : f32, tensor<?xf32>
}
"test.finish" () : () -> ()
}) : () -> ()
+// -----
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+"test.symbol_scope_isolated"() ({
+ // CHECK: func @inner_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"hello">>
+ func.func @inner_func(%t: tensor<?xf32, #enc1>)
+ -> tensor<?xf32, #enc1> {
+ // CHECK: return %[[arg0]]
+ return %t : tensor<?xf32, #enc1>
+ }
+
+ // CHECK: func @outer_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> (memref<?xf32, #test.memref_layout<"hello">>,
+ // CHECK-SAME: memref<?xf32, #test.memref_layout<"not hello">>)
+ func.func @outer_func(%t0: tensor<?xf32, #enc1>)
+ -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
+ // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+ %0 = call @inner_func(%t0)
+ : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
+
+ // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
+ // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"not hello">>
+ %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
+ // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
+ %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
+ -> tensor<?xf32, #enc2>
+
+ // CHECK: return %[[call]], %[[dummy]]
+ return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
+ }
+ "test.finish" () : () -> ()
+}) : () -> ()
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 8accf6e..755e3a3 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
// -----
+// CHECK-LABEL: fold_shufflevector
+// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32>
+llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> {
+ // CHECK-NOT: llvm.shufflevector
+ %c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32>
+ // CHECK: llvm.return %[[ARG1]]
+ llvm.return %c : vector<1xf32>
+}
+
+// -----
+
// Check that LLVM constants participate in cross-dialect constant folding. The
// resulting constant is created in the arith dialect because the last folded
// operation belongs to it.
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 358bd33..242c04f 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1035,6 +1035,20 @@ llvm.func @rocdl.s.wait.expcnt() {
llvm.return
}
+llvm.func @rocdl.s.wait.asynccnt() {
+ // CHECK-LABEL: rocdl.s.wait.asynccnt
+ // CHECK: rocdl.s.wait.asynccnt 0
+ rocdl.s.wait.asynccnt 0
+ llvm.return
+}
+
+llvm.func @rocdl.s.wait.tensorcnt() {
+ // CHECK-LABEL: rocdl.s.wait.tensorcnt
+ // CHECK: rocdl.s.wait.tensorcnt 0
+ rocdl.s.wait.tensorcnt 0
+ llvm.return
+}
+
// -----
llvm.func @rocdl.readfirstlane(%src : f32) -> f32 {
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 618ba34..66cae5c 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -1011,6 +1011,20 @@ module attributes { transform.target_tag = "start_here" } {
} -> tensor<1x1x4xf32>
return
}
+
+ func.func @generic_none(%arg0: tensor<128x128xi32>, %arg1: tensor<128x128xi32>, %arg2: tensor<128x128xi32>) {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<128x128xi32>, tensor<128x128xi32>)
+ outs(%arg2 : tensor<128x128xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ linalg.yield %out : i32
+ } -> tensor<128x128xi32>
+ return
+ }
}
// -----
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9616a3e..1df15e8 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -10,10 +10,10 @@
// TODO: Some test cases from this file should be moved to other dialects.
-// CHECK-LABEL: func @fill_inplace(
+// CHECK-LABEL: func private @fill_inplace(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
-// CHECK-NO-LAYOUT-MAP-LABEL: func @fill_inplace(%{{.*}}: memref<?xf32>) {
-func.func @fill_inplace(
+// CHECK-NO-LAYOUT-MAP-LABEL: func private @fill_inplace(%{{.*}}: memref<?xf32>) {
+func.func private @fill_inplace(
%A : tensor<?xf32> {bufferization.writable = true})
-> tensor<?xf32>
{
@@ -56,10 +56,10 @@ func.func @not_inplace(
// -----
-// CHECK-LABEL: func @not_inplace
+// CHECK-LABEL: func private @not_inplace
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>) {
-// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref<?x?xf32>) {
-func.func @not_inplace(
+// CHECK-NO-LAYOUT-MAP-LABEL: func private @not_inplace(%{{.*}}: memref<?x?xf32>) {
+func.func private @not_inplace(
%A : tensor<?x?xf32> {bufferization.writable = true})
-> tensor<?x?xf32>
{
@@ -235,7 +235,7 @@ func.func @dominance_violation_bug_1(
// -----
-func.func @gather_like(
+func.func private @gather_like(
%arg0 : tensor<?x?xf32> {bufferization.writable = false},
%arg1 : tensor<?xi32> {bufferization.writable = false},
%arg2 : tensor<?x?xf32> {bufferization.writable = true})
@@ -254,7 +254,7 @@ func.func @gather_like(
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @gather_like(
+// CHECK-LABEL: func private @gather_like(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32,
// CHECK-SAME: %[[ARG1:.+]]: memref<?xi32
// CHECK-SAME: %[[ARG2:.+]]: memref<?x?xf32
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a..93a0336 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.dot
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_dot
func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
@@ -20,6 +24,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.matvec
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_matvec
func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
@@ -41,6 +49,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.matmul
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_matmul
func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
@@ -138,6 +150,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_matmul
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: contraction_batch_matmul
func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
@@ -159,6 +175,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.contract
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: @matmul_as_contract
// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
@@ -220,6 +240,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.fill
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: func @test_vectorize_fill
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
@@ -259,70 +283,14 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_copy
-func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
- // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
- // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
- memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
- return
-}
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.pack
+///----------------------------------------------------------------------------------------
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
+// Note, see a similar test in:
+// * vectorization.mlir.
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy_0d
-func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
- // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
- // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
- // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
- // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
- // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
- memref.copy %A, %B : memref<f32> to memref<f32>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_copy_complex
-// CHECK-NOT: vector<
-func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
- memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// Input identical as the test in vectorization.mlir. Output is different -
-// vector sizes are inferred (rather than user-specified) and hence _no_
-// masking was used.
-
-func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+func.func @pack_no_padding(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
return %pack : tensor<4x1x32x16x2xf32>
}
@@ -336,7 +304,7 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK-LABEL: func.func @test_vectorize_pack(
+// CHECK-LABEL: func.func @pack_no_padding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = ub.poison : f32
@@ -349,13 +317,16 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+// Note, see a similar test in:
+// * vectorization.mlir.
+
+func.func @pack_with_padding(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
%pad = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
return %pack : tensor<32x4x1x16x2xf32>
}
-// CHECK-LABEL: func.func @test_vectorize_padded_pack(
+// CHECK-LABEL: func.func @pack_with_padding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
@@ -377,6 +348,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.map
+///----------------------------------------------------------------------------------------
+
func.func @vectorize_map(%arg0: memref<64xf32>,
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
@@ -403,6 +378,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.transpose
+///----------------------------------------------------------------------------------------
+
func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>,
%arg1: memref<32x64x16xf32>) {
linalg.transpose ins(%arg0 : memref<16x32x64xf32>)
@@ -424,6 +403,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.reduce
+///----------------------------------------------------------------------------------------
+
func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>,
%arg1: memref<16x64xf32>) {
linalg.reduce ins(%arg0 : memref<16x32x64xf32>)
@@ -449,6 +432,10 @@ module attributes {transform.with_named_sequence} {
// -----
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.generic
+///----------------------------------------------------------------------------------------
+
#matmul_trait = {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>,
@@ -1446,6 +1433,8 @@ module attributes {transform.with_named_sequence} {
// -----
+// TODO: Two Linalg Ops in one test - either split or document "why".
+
// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
// CHECK-LABEL: func @fused_broadcast_red_2d
@@ -1896,3 +1885,65 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for memref.copy
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @test_vectorize_copy
+func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+ memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_0d
+func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
+ // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
+ // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
+ // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
+ // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
+ memref.copy %A, %B : memref<f32> to memref<f32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_complex
+// CHECK-NOT: vector<
+func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
+ memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 11bea8d..1304a90 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1307,14 +1307,17 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
/// Tests for linalg.pack
///----------------------------------------------------------------------------------------
-// Input identical as the test in vectorization-with-patterns.mlir. Output is
-// different - vector sizes are inferred (rather than user-specified) and hence
-// masking was used.
+// This packing requires no padding, so no out-of-bounds read/write vector Ops.
-// CHECK-LABEL: func @test_vectorize_pack
+// Note, see a similar test in:
+// * vectorization-with-patterns.mlir
+// The output is identical (the input vector sizes == the inferred vector
+// sizes, i.e. the tensor sizes).
+
+// CHECK-LABEL: func @pack_no_padding
// CHECK-SAME: %[[SRC:.*]]: tensor<32x8x16xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<4x1x32x16x2xf32>
-func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
%pack = linalg.pack %src outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
return %pack : tensor<4x1x32x16x2xf32>
}
@@ -1325,9 +1328,9 @@ func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x1
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
-// CHECK: return %[[write]] : tensor<4x1x32x16x2xf32>
+// CHECK: return %[[WRITE]] : tensor<4x1x32x16x2xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) {
@@ -1339,14 +1342,18 @@ module attributes {transform.with_named_sequence} {
// -----
-// Input identical as the test in vectorization-with-patterns.mlir. Output is
-// different - vector sizes are inferred (rather than user-specified) and hence
-// masking was used.
+// This packing does require padding, so there are out-of-bounds read/write
+// vector Ops.
+
+// Note, see a similar test in:
+// * vectorization-with-patterns.mlir.
+// The output is different (the input vector sizes != inferred vector sizes,
+// i.e. the tensor sizes).
-// CHECK-LABEL: func @test_vectorize_padded_pack
+// CHECK-LABEL: func @pack_with_padding
// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32>
-func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
%pad = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
return %pack : tensor<32x4x1x16x2xf32>
@@ -1364,9 +1371,9 @@ func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<3
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -1378,10 +1385,46 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_dynamic_pack
+// This packing does require padding, so the vector.transfer_read is partially
+// out-of-bounds (the vector.transfer_write remains fully in-bounds).
+
+// Note, see a similar test in:
+// * vectorization-with-patterns.mlir.
+// The output is identical (in both cases the vector sizes are inferred).
+
+// CHECK-LABEL: func @pack_with_padding_no_vector_sizes
+// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32>
+func.func @pack_with_padding_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+ %pad = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+ return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]]
+// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @pack_with_dynamic_dims
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32>
-func.func @test_vectorize_dynamic_pack(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
%pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
return %pack : tensor<?x?x16x2xf32>
}
@@ -1418,64 +1461,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// -----
-
-// CHECK-LABEL: func @test_vectorize_pack_no_vector_sizes
-// CHECK-SAME: %[[SRC:.*]]: tensor<64x4xf32>,
-// CHECK-SAME: %[[DEST:.*]]: tensor<2x4x16x2xf32>
-func.func @test_vectorize_pack_no_vector_sizes(%src: tensor<64x4xf32>, %dest: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
- %pack = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %dest : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
- return %pack : tensor<2x4x16x2xf32>
-}
-// CHECK-DAG: %[[CST:.*]] = ub.poison : f32
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]]
-// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<64x4xf32> to vector<4x16x2x2xf32>
-// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
-// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
-// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
-// CHECK: return %[[WRITE]] : tensor<2x4x16x2xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
-// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>,
-// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32>
-func.func @test_vectorize_padded_pack_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
- %pad = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
- return %pack : tensor<32x4x1x16x2xf32>
-}
-// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]]
-// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
-// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]]
-// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 : !transform.any_op
- transform.yield
- }
-}
-
-
///----------------------------------------------------------------------------------------
/// Tests for other Ops
///----------------------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 16b7a5c..7160b52 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
// -----
+// CHECK-LABEL: func @reinterpret_constant_fold
+// CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
+func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_of_reinterpret
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// when the strides don't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// when the offset doesn't match.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-// CHECK: return %[[RES]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
+// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
+// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir
index 603ace8..3d4bec7 100644
--- a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir
+++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir
@@ -3,7 +3,7 @@
func.func @test_static_memref_alloc() {
%0 = memref.alloca() {test.ptr} : memref<10x20xf32>
// CHECK: Successfully generated alloc for operation: %[[ORIG:.*]] = memref.alloca() {test.ptr} : memref<10x20xf32>
- // CHECK: Generated: %{{.*}} = memref.alloca() : memref<10x20xf32>
+ // CHECK: Generated: %{{.*}} = memref.alloca() {acc.var_name = #acc.var_name<"test_alloc">} : memref<10x20xf32>
return
}
@@ -19,6 +19,6 @@ func.func @test_dynamic_memref_alloc() {
// CHECK: Generated: %[[DIM0:.*]] = memref.dim %[[ORIG]], %[[C0]] : memref<?x?xf32>
// CHECK: Generated: %[[C1:.*]] = arith.constant 1 : index
// CHECK: Generated: %[[DIM1:.*]] = memref.dim %[[ORIG]], %[[C1]] : memref<?x?xf32>
- // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+ // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"test_alloc">} : memref<?x?xf32>
return
}
diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir
new file mode 100644
index 0000000..8846c9e
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=firstprivate})" | FileCheck %s
+
+// CHECK: acc.firstprivate.recipe @firstprivate_scalar : memref<f32> init {
+// CHECK: ^bb0(%{{.*}}: memref<f32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32>
+// CHECK: acc.yield %[[ALLOC]] : memref<f32>
+// CHECK: } copy {
+// CHECK: ^bb0(%[[SRC:.*]]: memref<f32>, %[[DST:.*]]: memref<f32>):
+// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<f32> to memref<f32>
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK-NOT: destroy
+
+func.func @test_scalar() {
+ %0 = memref.alloca() {test.var = "scalar"} : memref<f32>
+ return
+}
+
+// -----
+
+// CHECK: acc.firstprivate.recipe @firstprivate_static_2d : memref<10x20xf32> init {
+// CHECK: ^bb0(%{{.*}}: memref<10x20xf32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32>
+// CHECK: } copy {
+// CHECK: ^bb0(%[[SRC:.*]]: memref<10x20xf32>, %[[DST:.*]]: memref<10x20xf32>):
+// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<10x20xf32> to memref<10x20xf32>
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK-NOT: destroy
+
+func.func @test_static_2d() {
+ %0 = memref.alloca() {test.var = "static_2d"} : memref<10x20xf32>
+ return
+}
+
+// -----
+
+// CHECK: acc.firstprivate.recipe @firstprivate_dynamic_2d : memref<?x?xf32> init {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>):
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32>
+// CHECK: } copy {
+// CHECK: ^bb0(%[[SRC:.*]]: memref<?x?xf32>, %[[DST:.*]]: memref<?x?xf32>):
+// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<?x?xf32> to memref<?x?xf32>
+// CHECK: acc.terminator
+// CHECK: } destroy {
+// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>):
+// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+func.func @test_dynamic_2d(%arg0: index, %arg1: index) {
+ %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_2d"} : memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK: acc.firstprivate.recipe @firstprivate_mixed_dims : memref<10x?xf32> init {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>):
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32>
+// CHECK: } copy {
+// CHECK: ^bb0(%[[SRC:.*]]: memref<10x?xf32>, %[[DST:.*]]: memref<10x?xf32>):
+// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<10x?xf32> to memref<10x?xf32>
+// CHECK: acc.terminator
+// CHECK: } destroy {
+// CHECK: ^bb0(%{{.*}}: memref<10x?xf32>, %[[VAL:.*]]: memref<10x?xf32>):
+// CHECK: memref.dealloc %[[VAL]] : memref<10x?xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+func.func @test_mixed_dims(%arg0: index) {
+ %0 = memref.alloc(%arg0) {test.var = "mixed_dims"} : memref<10x?xf32>
+ return
+}
+
+// -----
+
+// CHECK: acc.firstprivate.recipe @firstprivate_scalar_int : memref<i32> init {
+// CHECK: ^bb0(%{{.*}}: memref<i32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32>
+// CHECK: acc.yield %[[ALLOC]] : memref<i32>
+// CHECK: } copy {
+// CHECK: ^bb0(%[[SRC:.*]]: memref<i32>, %[[DST:.*]]: memref<i32>):
+// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<i32> to memref<i32>
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK-NOT: destroy
+
+func.func @test_scalar_int() {
+ %0 = memref.alloca() {test.var = "scalar_int"} : memref<i32>
+ return
+}
+
diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir
new file mode 100644
index 0000000..3d5a918
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=private})" | FileCheck %s
+
+// CHECK: acc.private.recipe @private_scalar : memref<f32> init {
+// CHECK: ^bb0(%{{.*}}: memref<f32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32>
+// CHECK: acc.yield %[[ALLOC]] : memref<f32>
+// CHECK: }
+// CHECK-NOT: destroy
+
+func.func @test_scalar() {
+ %0 = memref.alloca() {test.var = "scalar"} : memref<f32>
+ return
+}
+
+// -----
+
+// CHECK: acc.private.recipe @private_static_2d : memref<10x20xf32> init {
+// CHECK: ^bb0(%{{.*}}: memref<10x20xf32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32>
+// CHECK: }
+// CHECK-NOT: destroy
+
+func.func @test_static_2d() {
+ %0 = memref.alloca() {test.var = "static_2d"} : memref<10x20xf32>
+ return
+}
+
+// -----
+
+// CHECK: acc.private.recipe @private_dynamic_2d : memref<?x?xf32> init {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>):
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32>
+// CHECK: } destroy {
+// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>):
+// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+func.func @test_dynamic_2d(%arg0: index, %arg1: index) {
+ %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_2d"} : memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK: acc.private.recipe @private_mixed_dims : memref<10x?xf32> init {
+// CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>):
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32>
+// CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32>
+// CHECK: } destroy {
+// CHECK: ^bb0(%{{.*}}: memref<10x?xf32>, %[[VAL:.*]]: memref<10x?xf32>):
+// CHECK: memref.dealloc %[[VAL]] : memref<10x?xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+func.func @test_mixed_dims(%arg0: index) {
+ %0 = memref.alloc(%arg0) {test.var = "mixed_dims"} : memref<10x?xf32>
+ return
+}
+
+// -----
+
+// CHECK: acc.private.recipe @private_scalar_int : memref<i32> init {
+// CHECK: ^bb0(%{{.*}}: memref<i32>):
+// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32>
+// CHECK: acc.yield %[[ALLOC]] : memref<i32>
+// CHECK: }
+// CHECK-NOT: destroy
+
+func.func @test_scalar_int() {
+ %0 = memref.alloca() {test.var = "scalar_int"} : memref<i32>
+ return
+}
+
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index a1067ec..af09dc8 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -8,11 +8,11 @@
// Test bufferization using memref types that have no layout map.
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null
-// CHECK-LABEL: func @scf_for_yield_only(
+// CHECK-LABEL: func private @scf_for_yield_only(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
// CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK-SAME: ) -> memref<?xf32> {
-func.func @scf_for_yield_only(
+func.func private @scf_for_yield_only(
%A : tensor<?xf32> {bufferization.writable = false},
%B : tensor<?xf32> {bufferization.writable = true},
%lb : index, %ub : index, %step : index)
@@ -85,11 +85,11 @@ func.func @nested_scf_for(%A : tensor<?xf32> {bufferization.writable = true},
// -----
-// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
+// CHECK-LABEL: func private @scf_for_with_tensor.insert_slice
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
-func.func @scf_for_with_tensor.insert_slice(
+func.func private @scf_for_with_tensor.insert_slice(
%A : tensor<?xf32> {bufferization.writable = false},
%B : tensor<?xf32> {bufferization.writable = true},
%C : tensor<4xf32> {bufferization.writable = false},
@@ -471,11 +471,11 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
// -----
-// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict(
+// CHECK-LABEL: func private @parallel_insert_slice_no_conflict(
// CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index,
// CHECK-SAME: %[[arg1:.*]]: memref<?xf32, strided{{.*}}>,
// CHECK-SAME: %[[arg2:.*]]: memref<?xf32, strided{{.*}}>
-func.func @parallel_insert_slice_no_conflict(
+func.func private @parallel_insert_slice_no_conflict(
%idx: index,
%idx2: index,
%arg1: tensor<?xf32> {bufferization.writable = true},
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 5f95da2..f66cf7a 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -8,12 +8,12 @@
// Test bufferization using memref types that have no layout map.
// RUN: mlir-opt %s -one-shot-bufferize="unknown-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null
-// CHECK-LABEL: func @insert_slice_fun
+// CHECK-LABEL: func private @insert_slice_fun
// CHECK-SAME: %[[A0:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
// CHECK-SAME: %[[A1:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
// CHECK-SAME: %[[t0:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>,
// CHECK-SAME: %[[t1:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
-func.func @insert_slice_fun(
+func.func private @insert_slice_fun(
%A0 : tensor<?xf32> {bufferization.writable = false},
%A1 : tensor<?xf32> {bufferization.writable = true},
%t0 : tensor<4xf32> {bufferization.writable = false},
@@ -331,12 +331,12 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
// -----
// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)>
-// CHECK-LABEL: func.func @cast_retains_buffer_layout(
+// CHECK-LABEL: func.func private @cast_retains_buffer_layout(
// CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]>
// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>>
// CHECK: return %[[slice]]
-func.func @cast_retains_buffer_layout(
+func.func private @cast_retains_buffer_layout(
%t: tensor<?xf32>
{bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>},
%sz: index)
@@ -353,12 +353,12 @@ func.func @cast_retains_buffer_layout(
// -----
-// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided(
+// CHECK-LABEL: func private @cast_retains_buffer_layout_strided(
// CHECK-SAME: %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>>
// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>>
// CHECK: return %[[slice]]
-func.func @cast_retains_buffer_layout_strided(
+func.func private @cast_retains_buffer_layout_strided(
%t: tensor<?xf32>
{bufferization.buffer_layout = strided<[1], offset: 5>},
%sz: index)
@@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[A]] : index
+// CHECK: ^bb2:
+// CHECK: return %[[B]] : index
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ cf.cond_br %cond,^a, ^b
+^a:
+ return %casted, %a : tensor<10xf32>, index
+^b:
+ return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+ // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+ // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+ %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+ return %t_res, %v : tensor<10xf32>, index
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
index d6c886c..a0c59c0 100644
--- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -1,12 +1,14 @@
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1
// -----
-// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
-// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
-// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
+// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>}
// CHECK-LABEL: test_simple
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
new file mode 100644
index 0000000..51089df
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
new file mode 100644
index 0000000..8164509
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+
+// -----
+
+func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
+ return %0 : tensor<1x14x28xf16>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
+func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
+ %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
new file mode 100644
index 0000000..023a0e5
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -0,0 +1,311 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+///===----------------------------------------------===//
+/// Tests of `StepCompareFolder`
+///===----------------------------------------------===//
+
+
+///===------------------------------------===//
+/// Tests of `ugt` (unsigned greater than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @ugt_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 3 > [0, 1, 2] => [true, true, true] => true for all indices => fold
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ugt_constant_2_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ugt_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 > [0, 1, 2] => [true, true, false] => not same for all indices => don't fold
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ugt_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 3 => [false, false, false] => false for all indices => fold
+ %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ugt_constant_max_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_max_rhs() -> vector<3xi1> {
+ // The largest i64 possible:
+ %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+
+// -----
+
+// CHECK-LABEL: @ugt_constant_2_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 2 => [false, false, false] => false for all indices => fold
+ %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ugt_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ugt_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 1 => [false, false, true] => not same for all indices => don't fold
+ %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `uge` (unsigned greater than or equal)
+///===------------------------------------===//
+
+
+// CHECK-LABEL: @uge_constant_2_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @uge_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 >= [0, 1, 2] => [true, true, true] => true for all indices => fold
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_uge_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_uge_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 1 >= [0, 1, 2] => [true, true, false] => not same for all indices => don't fold
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @uge_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @uge_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 3 => [false, false, false] => false for all indices => fold
+ %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_uge_constant_2_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_uge_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 2 => [false, false, true] => not same for all indices => don't fold
+ %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+
+///===------------------------------------===//
+/// Tests of `ult` (unsigned less than)
+///===------------------------------------===//
+
+
+// CHECK-LABEL: @ult_constant_2_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ult_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 < [0, 1, 2] => [false, false, false] => false for all indices => fold
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ult_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ult_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 1 < [0, 1, 2] => [false, false, true] => not same for all indices => don't fold
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ult_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ult_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] < 3 => [true, true, true] => true for all indices => fold
+ %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ult_constant_2_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ult_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] < 2 => [true, true, false] => not same for all indices => don't fold
+ %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `ule` (unsigned less than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @ule_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ule_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ule_constant_2_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ule_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ule_constant_2_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ule_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ule_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ule_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `eq` (equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @eq_constant_3
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @eq_constant_3() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_eq_constant_2
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_eq_constant_2() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+/// Tests of `ne` (not equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @ne_constant_3
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @ne_constant_3() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ne_constant_2
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @negative_ne_constant_2() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e..e5a98b5 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return %0 : vector<3x2x2xf32>
}
-// CHECK-LABEL: func @negative_vector_fma_3d
-// CHECK-NOT: vector.extract_strided_slice
-// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-// CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
+// CHECK: return %[[I2]] : vector<3x2x2xf32>
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> {
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise_3D_to_2D(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+ %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+ return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise_3D_to_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
+// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
+// CHECK: return %[[I1]] : vector<2x2x2xf32>
+
+
+func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
+ %0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32>
+ return %0 : vector<2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func @elementwise_4D_to_2D
+// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NOT: arith.addf
+// CHECK: return
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb76392..401cdd29 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1) {
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
// CHECK-PROP: return
// CHECK-PROP: }
+
+// -----
+func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
+ %r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
+ (vector<1xf32>, vector<2xf32>, vector<1xf32>) {
+ %2 = "some_def"() : () -> (vector<32xf32>)
+ %3 = "some_def"() : () -> (vector<64xf32>)
+ gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
+ }
+ %r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
+ return %r0 : vector<1xf32>
+}
+
+// CHECK-PROP: func @dedup_unused_result
+// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
+// CHECK-PROP: %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
+// CHECK-PROP: %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP: gpu.yield %[[Y0]] : vector<32xf32>
+// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
index b9b3420..a25abbd 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
module {
- wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+ wasmssa.import_global "from_js" from "env" as @global_0 : i32
wasmssa.global @global_1 i32 : {
%0 = wasmssa.const 10 : i32
@@ -21,7 +21,7 @@ module {
}
}
-// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 : i32
// CHECK-LABEL: wasmssa.global @global_1 i32 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
index 01068cb..cee3c69 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s | FileCheck %s
-// CHECK-LABEL: wasmssa.func nested @func_0(
+// CHECK-LABEL: wasmssa.func @func_0(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
// CHECK: wasmssa.if %[[VAL_0]] : {
@@ -12,7 +12,7 @@
// CHECK: }> ^bb1
// CHECK: ^bb1(%[[VAL_3:.*]]: f32):
// CHECK: wasmssa.return %[[VAL_3]] : f32
-wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
+wasmssa.func @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
%cond = wasmssa.local_get %arg0 : ref to i32
wasmssa.if %cond : {
%c0 = wasmssa.const 0.5 : f32
@@ -25,7 +25,7 @@ wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
wasmssa.return %retVal : f32
}
-// CHECK-LABEL: wasmssa.func nested @func_1(
+// CHECK-LABEL: wasmssa.func @func_1(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
@@ -38,7 +38,7 @@ wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 {
// CHECK: ^bb1:
// CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32
// CHECK: wasmssa.return %[[VAL_4]] : i32
-wasmssa.func nested @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 {
+wasmssa.func @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 {
%cond = wasmssa.local_get %arg0 : ref to i32
%var = wasmssa.local of type i32
%zero = wasmssa.const 0
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
index 3cc0548..dc23229 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir
@@ -5,13 +5,13 @@ module {
wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
- wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
- wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+ wasmssa.import_global "glob" from "my_module" as @global_0 : i32
+ wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32
}
// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
-// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
-// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
index 3f6423f..f613ebf 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
module {
- wasmssa.func nested @func_0() -> f32 {
+ wasmssa.func @func_0() -> f32 {
%0 = wasmssa.local of type f32
%1 = wasmssa.local of type f32
%2 = wasmssa.const 8.000000e+00 : f32
@@ -9,7 +9,7 @@ module {
%4 = wasmssa.add %2 %3 : f32
wasmssa.return %4 : f32
}
- wasmssa.func nested @func_1() -> i32 {
+ wasmssa.func @func_1() -> i32 {
%0 = wasmssa.local of type i32
%1 = wasmssa.local of type i32
%2 = wasmssa.const 8 : i32
@@ -17,13 +17,13 @@ module {
%4 = wasmssa.add %2 %3 : i32
wasmssa.return %4 : i32
}
- wasmssa.func nested @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 {
+ wasmssa.func @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 {
%0 = wasmssa.const 3 : i32
wasmssa.return %0 : i32
}
}
-// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32
@@ -31,7 +31,7 @@ module {
// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : f32
// CHECK: wasmssa.return %[[VAL_4]] : f32
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32
@@ -39,7 +39,7 @@ module {
// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32
// CHECK: wasmssa.return %[[VAL_4]] : i32
-// CHECK-LABEL: wasmssa.func nested @func_2(
+// CHECK-LABEL: wasmssa.func @func_2(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
index 47551db..ca6ebe0 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
-// CHECK: wasmssa.memory @mem0 public !wasmssa<limit[0: 65536]>
-wasmssa.memory @mem0 public !wasmssa<limit[0:65536]>
-
-// CHECK: wasmssa.memory @mem1 nested !wasmssa<limit[512:]>
+// CHECK: wasmssa.memory @mem1 !wasmssa<limit[512:]>
wasmssa.memory @mem1 !wasmssa<limit[512:]>
+
+// CHECK: wasmssa.memory exported @mem2 !wasmssa<limit[0: 65536]>
+wasmssa.memory exported @mem2 !wasmssa<limit[0:65536]>
diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
index 5a874f4..ea630de 100644
--- a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
+++ b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s | FileCheck %s
-// CHECK: wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0: 65536]>
-wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0:65536]>
+// CHECK: wasmssa.table exported @tab0 !wasmssa<tabletype !wasmssa.externref [0: 65536]>
+wasmssa.table exported @tab0 !wasmssa<tabletype !wasmssa.externref [0:65536]>
-// CHECK: wasmssa.table @tab1 nested !wasmssa<tabletype !wasmssa.funcref [348:]>
+// CHECK: wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]>
wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]>
diff --git a/mlir/test/Dialect/WasmSSA/extend-invalid.mlir b/mlir/test/Dialect/WasmSSA/extend-invalid.mlir
index 8d78280..7687e5f 100644
--- a/mlir/test/Dialect/WasmSSA/extend-invalid.mlir
+++ b/mlir/test/Dialect/WasmSSA/extend-invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-wasmssa.func nested @extend_low_64() -> i32 {
+wasmssa.func @extend_low_64() -> i32 {
%0 = wasmssa.const 10 : i32
// expected-error@+1 {{extend op can only take 8, 16 or 32 bits. Got 64}}
%1 = wasmssa.extend 64 low bits from %0: i32
@@ -10,7 +10,7 @@ wasmssa.func nested @extend_low_64() -> i32 {
// -----
-wasmssa.func nested @extend_too_much() -> i32 {
+wasmssa.func @extend_too_much() -> i32 {
%0 = wasmssa.const 10 : i32
// expected-error@+1 {{trying to extend the 32 low bits from a 'i32' value is illegal}}
%1 = wasmssa.extend 32 low bits from %0: i32
diff --git a/mlir/test/Dialect/WasmSSA/global-invalid.mlir b/mlir/test/Dialect/WasmSSA/global-invalid.mlir
index b9cafd8..c5bc606 100644
--- a/mlir/test/Dialect/WasmSSA/global-invalid.mlir
+++ b/mlir/test/Dialect/WasmSSA/global-invalid.mlir
@@ -13,7 +13,7 @@ module {
// -----
module {
- wasmssa.import_global "glob" from "my_module" as @global_0 mutable nested : i32
+ wasmssa.import_global "glob" from "my_module" as @global_0 mutable : i32
wasmssa.global @global_1 i32 : {
// expected-error@+1 {{global.get op is considered constant if it's referring to a import.global symbol marked non-mutable}}
%0 = wasmssa.global_get @global_0 : i32
@@ -30,3 +30,13 @@ module {
wasmssa.return %0 : i32
}
}
+
+// -----
+
+module {
+ // expected-error@+1 {{expecting either `exported` or symbol name. got exproted}}
+ wasmssa.global exproted @global_1 i32 : {
+ %0 = wasmssa.const 17 : i32
+ wasmssa.return %0 : i32
+ }
+}
diff --git a/mlir/test/Dialect/WasmSSA/locals-invalid.mlir b/mlir/test/Dialect/WasmSSA/locals-invalid.mlir
index 35c590b..eaad80e 100644
--- a/mlir/test/Dialect/WasmSSA/locals-invalid.mlir
+++ b/mlir/test/Dialect/WasmSSA/locals-invalid.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
+wasmssa.func @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
%0 = wasmssa.const 3 : i64
// expected-error@+1 {{input type and result type of local.set do not match}}
wasmssa.local_set %arg0 : ref to i32 to %0 : i64
@@ -9,7 +9,7 @@ wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 {
// -----
-wasmssa.func nested @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 {
+wasmssa.func @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 {
%0 = wasmssa.const 3 : i64
// expected-error@+1 {{input type and output type of local.tee do not match}}
%1 = wasmssa.local_tee %arg0 : ref to i32 to %0 : i64
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 228ef69d..ebbe3ce 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>
// -----
func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{result shape must not exceed mem_desc shape}}
+ // expected-error@+1 {{data shape must not exceed mem_desc shape}}
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
return
}
@@ -871,6 +871,14 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
}
// -----
+func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
+ return
+}
+
+
+// -----
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
// expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16>
@@ -892,30 +900,16 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
}
// -----
-func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{result shape must not exceed source shape}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
- return
-}
-
-// -----
-func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
- // expected-error@+1 {{result must inherit the source strides}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
- return
-}
-
-// -----
-func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{failed to verify that all of {src, res} have same element type}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+ // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
// -----
-func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error@+1 {{result rank must not exceed source rank}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+ // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index bb37902..0a10f68 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -825,53 +825,73 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}
-// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
gpu.return
}
-// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
gpu.return
}
+// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>)
+gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+ %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
+gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ gpu.return
+}
-// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
gpu.return
}
-// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16>
+ xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16>
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
+gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
diff --git a/mlir/test/Integration/GPU/SPIRV/simple_add.mlir b/mlir/test/Integration/GPU/SPIRV/simple_add.mlir
index cb16c37..b3154d4 100644
--- a/mlir/test/Integration/GPU/SPIRV/simple_add.mlir
+++ b/mlir/test/Integration/GPU/SPIRV/simple_add.mlir
@@ -3,7 +3,16 @@
// RUN: | FileCheck %s
// CHECK: data =
-// CHECK-RAW: [[[7.7, 0, 0], [7.7, 0, 0], [7.7, 0, 0]], [[0, 7.7, 0], [0, 7.7, 0], [0, 7.7, 0]], [[0, 0, 7.7], [0, 0, 7.7], [0, 0, 7.7]]]
+// CHECK{LITERAL}: [[[7.7, 0, 0],
+// CHECK{LITERAL}: [7.7, 0, 0],
+// CHECK{LITERAL}: [7.7, 0, 0]],
+// CHECK{LITERAL}: [[0, 7.7, 0],
+// CHECK{LITERAL}: [0, 7.7, 0],
+// CHECK{LITERAL}: [0, 7.7, 0]],
+// CHECK{LITERAL}: [[0, 0, 7.7],
+// CHECK{LITERAL}: [0, 0, 7.7],
+// CHECK{LITERAL}: [0, 0, 7.7]]]
+
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<
diff --git a/mlir/test/Pass/remark-final.mlir b/mlir/test/Pass/remark-final.mlir
new file mode 100644
index 0000000..325271e
--- /dev/null
+++ b/mlir/test/Pass/remark-final.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s --test-remark --remarks-filter="category.*" --remark-policy=final 2>&1 | FileCheck %s
+// RUN: mlir-opt %s --test-remark --remarks-filter="category.*" --remark-policy=final --remark-format=yaml --remarks-output-file=%t.yaml
+// RUN: FileCheck --check-prefix=CHECK-YAML %s < %t.yaml
+module @foo {
+ "test.op"() : () -> ()
+
+}
+
+// CHECK-YAML-NOT: This is a test passed remark (should be dropped)
+// CHECK-YAML-DAG: !Analysis
+// CHECK-YAML-DAG: !Failure
+// CHECK-YAML-DAG: !Passed
+
+// CHECK-NOT: This is a test passed remark (should be dropped)
+// CHECK-DAG: remark: [Analysis] test-remark
+// CHECK-DAG: remark: [Failure] test-remark | Category:category-2-failed
+// CHECK-DAG: remark: [Passed] test-remark | Category:category-1-passed
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index e056e43..61376b8 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -240,11 +240,10 @@ define void @subprogram() !dbg !3 {
define void @func_loc() !dbg !3 {
ret void
}
-; CHECK-DAG: #[[NAME_LOC:.+]] = loc("func_loc")
; CHECK-DAG: #[[FILE_LOC:.+]] = loc("debug-info.ll":42:0)
; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<id = distinct[{{.*}}]<>, compileUnit = #{{.*}}, scope = #{{.*}}, name = "func_loc", file = #{{.*}}, line = 42, subprogramFlags = Definition>
-; CHECK: loc(fused<#[[SP]]>[#[[NAME_LOC]], #[[FILE_LOC]]]
+; CHECK: loc(fused<#[[SP]]>[#[[FILE_LOC]]]
!llvm.dbg.cu = !{!1}
!llvm.module.flags = !{!0}
diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
index cc3d799..00d09ba 100644
--- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll
+++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll
@@ -393,6 +393,12 @@ declare void @alwaysinline_attribute() alwaysinline
// -----
+; CHECK-LABEL: @inlinehint_attribute
+; CHECK-SAME: attributes {inline_hint}
+declare void @inlinehint_attribute() inlinehint
+
+// -----
+
; CHECK-LABEL: @optnone_attribute
; CHECK-SAME: attributes {no_inline, optimize_none}
declare void @optnone_attribute() noinline optnone
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 69814f2..cc243c8 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2555,6 +2555,17 @@ llvm.func @always_inline() attributes { always_inline } {
// -----
+// CHECK-LABEL: @inline_hint
+// CHECK-SAME: #[[ATTRS:[0-9]+]]
+llvm.func @inline_hint() attributes { inline_hint } {
+ llvm.return
+}
+
+// CHECK: #[[ATTRS]]
+// CHECK-SAME: inlinehint
+
+// -----
+
// CHECK-LABEL: @optimize_none
// CHECK-SAME: #[[ATTRS:[0-9]+]]
llvm.func @optimize_none() attributes { no_inline, optimize_none } {
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000..04e2ddf
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+ %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+ %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b36154..6cccfe4 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// -----
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+ %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
@@ -559,3 +567,25 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
%res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1
llvm.return
}
+
+// -----
+
+// Test that ensures invalid row/col layouts for matrices A and B are not accepted
+llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
+ // expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}}
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
+ intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+// -----
+
+// Test for range validation - invalid range where lower == upper but not at extremes
+func.func @invalid_range_equal_bounds() {
+ // expected-error @below {{invalid range attribute: Lower == Upper, but they aren't min (0) or max (4294967295) value! This is an invalid constant range.}}
+ %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 00a479d..594ae48 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -152,6 +152,10 @@ llvm.func @nvvm_special_regs() -> i32 {
%74 = nvvm.read.ptx.sreg.lanemask.ge : i32
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt
%75 = nvvm.read.ptx.sreg.lanemask.gt : i32
+ // CHECK: %76 = call range(i32 0, 0) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+ %76 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 0> : i32
+ // CHECK: %77 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+ %77 = nvvm.read.ptx.sreg.tid.x range <i32, 4294967295, 4294967295> : i32
llvm.return %1 : i32
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index fdd2c91..6536fac 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -276,6 +276,20 @@ llvm.func @rocdl.s.wait.expcnt() {
llvm.return
}
+llvm.func @rocdl.s.wait.asynccnt() {
+ // CHECK-LABEL: rocdl.s.wait.asynccnt
+ // CHECK-NEXT: call void @llvm.amdgcn.s.wait.asynccnt(i16 0)
+ rocdl.s.wait.asynccnt 0
+ llvm.return
+}
+
+llvm.func @rocdl.s.wait.tensorcnt() {
+ // CHECK-LABEL: rocdl.s.wait.tensorcnt
+ // CHECK-NEXT: call void @llvm.amdgcn.s.wait.tensorcnt(i16 0)
+ rocdl.s.wait.tensorcnt 0
+ llvm.return
+}
+
llvm.func @rocdl.setprio() {
// CHECK: call void @llvm.amdgcn.s.setprio(i16 0)
rocdl.s.setprio 0
diff --git a/mlir/test/Target/Wasm/abs.mlir b/mlir/test/Target/Wasm/abs.mlir
index 9c45ba7..fe3602a 100644
--- a/mlir/test/Target/Wasm/abs.mlir
+++ b/mlir/test/Target/Wasm/abs.mlir
@@ -12,12 +12,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @abs_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @abs_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f32
// CHECK: wasmssa.return %[[VAL_1]] : f32
-// CHECK-LABEL: wasmssa.func @abs_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @abs_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f64
// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/add_div.mlir b/mlir/test/Target/Wasm/add_div.mlir
new file mode 100644
index 0000000..8a87c60
--- /dev/null
+++ b/mlir/test/Target/Wasm/add_div.mlir
@@ -0,0 +1,40 @@
+// RUN: yaml2obj %S/inputs/add_div.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+ (module $test.wasm
+ (type (;0;) (func (param i32) (result i32)))
+ (type (;1;) (func (param i32 i32) (result i32)))
+ (import "env" "twoTimes" (func $twoTimes (type 0)))
+ (func $add (type 1) (param i32 i32) (result i32)
+ local.get 0
+ call $twoTimes
+ local.get 1
+ call $twoTimes
+ i32.add
+ i32.const 2
+ i32.div_s)
+ (memory (;0;) 2)
+ (global $__stack_pointer (mut i32) (i32.const 66560))
+ (export "memory" (memory 0))
+ (export "add" (func $add)))
+*/
+
+// CHECK-LABEL: wasmssa.import_func "twoTimes" from "env" as @func_0 {type = (i32) -> i32}
+
+// CHECK-LABEL: wasmssa.func exported @add(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>,
+// CHECK-SAME: %[[ARG1:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.call @func_0(%[[VAL_0]]) : (i32) -> i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.local_get %[[ARG1]] : ref to i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.call @func_0(%[[VAL_2]]) : (i32) -> i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_3]] : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.div_si %[[VAL_4]] %[[VAL_5]] : i32
+// CHECK: wasmssa.return %[[VAL_6]] : i32
+// CHECK: }
+// CHECK: wasmssa.memory exported @memory !wasmssa<limit[2:]>
+
+// CHECK-LABEL: wasmssa.global @global_0 i32 mutable : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 66560 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
diff --git a/mlir/test/Target/Wasm/and.mlir b/mlir/test/Target/Wasm/and.mlir
index 4c0fea0..323d41a 100644
--- a/mlir/test/Target/Wasm/and.mlir
+++ b/mlir/test/Target/Wasm/and.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @and_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @and_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.and %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @and_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @and_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.and %0 %1 : i64
diff --git a/mlir/test/Target/Wasm/block.mlir b/mlir/test/Target/Wasm/block.mlir
new file mode 100644
index 0000000..c85fc1e
--- /dev/null
+++ b/mlir/test/Target/Wasm/block.mlir
@@ -0,0 +1,16 @@
+// RUN: yaml2obj %S/inputs/block.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(func(export "i_am_a_block")
+(block $i_am_a_block)
+)
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @i_am_a_block() {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
diff --git a/mlir/test/Target/Wasm/block_complete_type.mlir b/mlir/test/Target/Wasm/block_complete_type.mlir
new file mode 100644
index 0000000..67df198
--- /dev/null
+++ b/mlir/test/Target/Wasm/block_complete_type.mlir
@@ -0,0 +1,24 @@
+// RUN: yaml2obj %S/inputs/block_complete_type.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (type (;0;) (func (param i32) (result i32)))
+ (type (;1;) (func (result i32)))
+ (func (;0;) (type 1) (result i32)
+ i32.const 14
+ block (param i32) (result i32) ;; label = @1
+ i32.const 1
+ i32.add
+ end))
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 14 : i32
+// CHECK: wasmssa.block(%[[VAL_0]]) : i32 : {
+// CHECK: ^bb0(%[[VAL_1:.*]]: i32):
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_2]] : i32
+// CHECK: wasmssa.block_return %[[VAL_3]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_4:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_4]] : i32
diff --git a/mlir/test/Target/Wasm/block_value_type.mlir b/mlir/test/Target/Wasm/block_value_type.mlir
new file mode 100644
index 0000000..fa30f08
--- /dev/null
+++ b/mlir/test/Target/Wasm/block_value_type.mlir
@@ -0,0 +1,19 @@
+// RUN: yaml2obj %S/inputs/block_value_type.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (type (;0;) (func (result i32)))
+ (func (;0;) (type 0) (result i32)
+ block (result i32) ;; label = @1
+ i32.const 17
+ end))
+*/
+
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: wasmssa.block : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i32
+// CHECK: wasmssa.block_return %[[VAL_0]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_1:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/branch_if.mlir b/mlir/test/Target/Wasm/branch_if.mlir
new file mode 100644
index 0000000..c91ff37
--- /dev/null
+++ b/mlir/test/Target/Wasm/branch_if.mlir
@@ -0,0 +1,29 @@
+// RUN: yaml2obj %S/inputs/branch_if.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (type $produce_i32 (func (result i32)))
+ (func (type $produce_i32)
+ (block $my_block (type $produce_i32)
+ i32.const 1
+ i32.const 2
+ br_if $my_block
+ i32.const 1
+ i32.add
+ )
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: wasmssa.block : {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
+// CHECK: wasmssa.branch_if %[[VAL_1]] to level 0 with args(%[[VAL_0]] : i32) else ^bb1
+// CHECK: ^bb1:
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_2]] : i32
+// CHECK: wasmssa.block_return %[[VAL_3]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_4:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_4]] : i32
diff --git a/mlir/test/Target/Wasm/call.mlir b/mlir/test/Target/Wasm/call.mlir
new file mode 100644
index 0000000..c0169aa
--- /dev/null
+++ b/mlir/test/Target/Wasm/call.mlir
@@ -0,0 +1,17 @@
+// RUN: yaml2obj %S/inputs/call.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(func $forty_two (result i32)
+i32.const 42)
+(func(export "forty_two")(result i32)
+call $forty_two))
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
+
+// CHECK-LABEL: wasmssa.func exported @forty_two() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.call @func_0 : () -> i32
+// CHECK: wasmssa.return %[[VAL_0]] : i32
diff --git a/mlir/test/Target/Wasm/clz.mlir b/mlir/test/Target/Wasm/clz.mlir
index 3e6641d..858c09d 100644
--- a/mlir/test/Target/Wasm/clz.mlir
+++ b/mlir/test/Target/Wasm/clz.mlir
@@ -14,12 +14,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @clz_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @clz_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i32
// CHECK: wasmssa.return %[[VAL_1]] : i32
-// CHECK-LABEL: wasmssa.func @clz_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @clz_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i64
// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/comparison_ops.mlir b/mlir/test/Target/Wasm/comparison_ops.mlir
new file mode 100644
index 0000000..91e3a6a
--- /dev/null
+++ b/mlir/test/Target/Wasm/comparison_ops.mlir
@@ -0,0 +1,269 @@
+// RUN: yaml2obj %S/inputs/comparison_ops.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $lt_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.lt_s
+ )
+ (func $le_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.le_s
+ )
+ (func $lt_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.lt_u
+ )
+ (func $le_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.le_u
+ )
+ (func $gt_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.gt_s
+ )
+ (func $gt_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.gt_u
+ )
+ (func $ge_si32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.ge_s
+ )
+ (func $ge_ui32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.ge_u
+ )
+ (func $lt_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.lt_s
+ )
+ (func $le_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.le_s
+ )
+ (func $lt_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.lt_u
+ )
+ (func $le_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.le_u
+ )
+ (func $gt_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.gt_s
+ )
+ (func $gt_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.gt_u
+ )
+ (func $ge_si64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.ge_s
+ )
+ (func $ge_ui64 (result i32)
+ i64.const 12
+ i64.const 50
+ i64.ge_u
+ )
+ (func $lt_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.lt
+ )
+ (func $le_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.le
+ )
+ (func $gt_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.gt
+ )
+ (func $ge_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.ge
+ )
+ (func $lt_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.lt
+ )
+ (func $le_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.le
+ )
+ (func $gt_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.gt
+ )
+ (func $ge_f64 (result i32)
+ f64.const 5
+ f64.const 14
+ f64.ge
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_4() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_5() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_6() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_7() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_8() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_9() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_10() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_11() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.le_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_12() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_13() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_14() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_15() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_16() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_17() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.le %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_18() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_19() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_20() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.lt %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_21() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.le %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_22() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.gt %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_23() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ge %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
diff --git a/mlir/test/Target/Wasm/const.mlir b/mlir/test/Target/Wasm/const.mlir
index aa9e76f..adb792a 100644
--- a/mlir/test/Target/Wasm/const.mlir
+++ b/mlir/test/Target/Wasm/const.mlir
@@ -16,22 +16,22 @@
)
*/
-// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
// CHECK: }
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i64
// CHECK: wasmssa.return %[[VAL_0]] : i64
// CHECK: }
-// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_2() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 4.000000e+00 : f32
// CHECK: wasmssa.return %[[VAL_0]] : f32
// CHECK: }
-// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 {
+// CHECK-LABEL: wasmssa.func @func_3() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 9.000000e+00 : f64
// CHECK: wasmssa.return %[[VAL_0]] : f64
// CHECK: }
diff --git a/mlir/test/Target/Wasm/convert.mlir b/mlir/test/Target/Wasm/convert.mlir
new file mode 100644
index 0000000..ddc29a7
--- /dev/null
+++ b/mlir/test/Target/Wasm/convert.mlir
@@ -0,0 +1,85 @@
+// RUN: yaml2obj %S/inputs/convert.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func (export "convert_i32_u_to_f32") (result f32)
+ i32.const 10
+ f32.convert_i32_u
+ )
+
+ (func (export "convert_i32_s_to_f32") (result f32)
+ i32.const 42
+ f32.convert_i32_s
+ )
+
+ (func (export "convert_i64_u_to_f32") (result f32)
+ i64.const 17
+ f32.convert_i64_u
+ )
+
+ (func (export "convert_i64s_to_f32") (result f32)
+ i64.const 10
+ f32.convert_i64_s
+ )
+
+ (func (export "convert_i32_u_to_f64") (result f64)
+ i32.const 10
+ f64.convert_i32_u
+ )
+
+ (func (export "convert_i32_s_to_f64") (result f64)
+ i32.const 42
+ f64.convert_i32_s
+ )
+
+ (func (export "convert_i64_u_to_f64") (result f64)
+ i64.const 17
+ f64.convert_i64_u
+ )
+
+ (func (export "convert_i64s_to_f64") (result f64)
+ i64.const 10
+ f64.convert_i64_s
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_u_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i32 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_s_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i32 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64_u_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i64 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64s_to_f32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i64 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_u_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i32 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func exported @convert_i32_s_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i32 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64_u_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i64 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func exported @convert_i64s_to_f64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i64 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/copysign.mlir b/mlir/test/Target/Wasm/copysign.mlir
index 33d7a56..90c5b11 100644
--- a/mlir/test/Target/Wasm/copysign.mlir
+++ b/mlir/test/Target/Wasm/copysign.mlir
@@ -16,14 +16,14 @@
)
*/
-// CHECK-LABEL: wasmssa.func @copysign_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @copysign_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
// CHECK: }
-// CHECK-LABEL: wasmssa.func @copysign_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @copysign_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/ctz.mlir b/mlir/test/Target/Wasm/ctz.mlir
index 6c0806f..9e7cc5e 100644
--- a/mlir/test/Target/Wasm/ctz.mlir
+++ b/mlir/test/Target/Wasm/ctz.mlir
@@ -14,12 +14,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @ctz_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @ctz_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32
// CHECK: wasmssa.return %[[VAL_1]] : i32
-// CHECK-LABEL: wasmssa.func @ctz_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @ctz_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i64
// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/demote.mlir b/mlir/test/Target/Wasm/demote.mlir
new file mode 100644
index 0000000..3d2bc05
--- /dev/null
+++ b/mlir/test/Target/Wasm/demote.mlir
@@ -0,0 +1,15 @@
+// RUN: yaml2obj %S/inputs/demote.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func $main (result f32)
+ f64.const 2.24
+ f32.demote_f64
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 2.240000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.demote %[[VAL_0]] : f64 to f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
diff --git a/mlir/test/Target/Wasm/div.mlir b/mlir/test/Target/Wasm/div.mlir
index c91f780..4967d96 100644
--- a/mlir/test/Target/Wasm/div.mlir
+++ b/mlir/test/Target/Wasm/div.mlir
@@ -66,61 +66,61 @@
)
*/
-// CHECK-LABEL: wasmssa.func @div_u_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_u_i32_zero() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i32_zero() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_s_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_s_i32_zero() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i32_zero() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func @div_u_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_u_i64_zero() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_u_i64_zero() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_s_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_s_i64_zero() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @div_s_i64_zero() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func @div_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @div_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func @div_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @div_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/double_nested_loop.mlir b/mlir/test/Target/Wasm/double_nested_loop.mlir
new file mode 100644
index 0000000..8b3e499
--- /dev/null
+++ b/mlir/test/Target/Wasm/double_nested_loop.mlir
@@ -0,0 +1,63 @@
+// RUN: yaml2obj %S/inputs/double_nested_loop.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func
+ ;; declare two local variables (wasm locals are zero-initialized)
+ (local $i i32)
+ (local $j i32)
+
+ (loop $my_loop
+
+ ;; add one to $i
+ local.get $i
+ i32.const 1
+ i32.add
+ local.set $i
+ (loop $my_second_loop (result i32)
+ i32.const 1
+ local.get $j
+ i32.const 12
+ i32.add
+ local.tee $j
+ local.get $i
+ i32.gt_s
+ br_if $my_second_loop
+ )
+ i32.const 10
+ i32.lt_s
+ br_if $my_loop
+ )
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
+// CHECK: wasmssa.loop : {
+// CHECK: %[[VAL_2:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32
+// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_4]] : i32
+// CHECK: wasmssa.loop : {
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32
+// CHECK: %[[VAL_7:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_8:.*]] = wasmssa.add %[[VAL_6]] %[[VAL_7]] : i32
+// CHECK: %[[VAL_9:.*]] = wasmssa.local_tee %[[VAL_1]] : ref to i32 to %[[VAL_8]] : i32
+// CHECK: %[[VAL_10:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_11:.*]] = wasmssa.gt_si %[[VAL_9]] %[[VAL_10]] : i32 -> i32
+// CHECK: wasmssa.branch_if %[[VAL_11]] to level 0 else ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block_return %[[VAL_5]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_12:.*]]: i32):
+// CHECK: %[[VAL_13:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_14:.*]] = wasmssa.lt_si %[[VAL_12]] %[[VAL_13]] : i32 -> i32
+// CHECK: wasmssa.branch_if %[[VAL_14]] to level 0 else ^bb2
+// CHECK: ^bb2:
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
diff --git a/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir b/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir
new file mode 100644
index 0000000..5c98f1a
--- /dev/null
+++ b/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir
@@ -0,0 +1,53 @@
+// RUN: yaml2obj %S/inputs/empty_blocks_list_and_stack.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func (param $num i32)
+ (block $b1
+ (block $b2
+ (block $b3
+ )
+ )
+ )
+ )
+
+ (func (param $num i32)
+ (block $b1)
+ (block $b2)
+ (block $b3)
+ )
+)
+
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
+
+// CHECK-LABEL: wasmssa.func @func_1(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) {
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb2
+// CHECK: ^bb2:
+// CHECK: wasmssa.block : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb3
+// CHECK: ^bb3:
+// CHECK: wasmssa.return
diff --git a/mlir/test/Target/Wasm/eq.mlir b/mlir/test/Target/Wasm/eq.mlir
new file mode 100644
index 0000000..ba3ae2f
--- /dev/null
+++ b/mlir/test/Target/Wasm/eq.mlir
@@ -0,0 +1,56 @@
+// RUN: yaml2obj %S/inputs/eq.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $eq_i32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.eq
+ )
+
+ (func $eq_i64 (result i32)
+ i64.const 20
+ i64.const 5
+ i64.eq
+ )
+
+ (func $eq_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.eq
+ )
+
+ (func $eq_f64 (result i32)
+ f64.const 17
+ f64.const 0
+ f64.eq
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/eqz.mlir b/mlir/test/Target/Wasm/eqz.mlir
new file mode 100644
index 0000000..55cf94a
--- /dev/null
+++ b/mlir/test/Target/Wasm/eqz.mlir
@@ -0,0 +1,21 @@
+// RUN: yaml2obj %S/inputs/eqz.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func (export "eqz_i32") (result i32)
+ i32.const 13
+ i32.eqz)
+
+ (func (export "eqz_i64") (result i32)
+ i64.const 13
+ i64.eqz)
+)
+*/
+// CHECK-LABEL: wasmssa.func exported @eqz_i32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 13 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.eqz %[[VAL_0]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func exported @eqz_i64() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 13 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.eqz %[[VAL_0]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/extend.mlir b/mlir/test/Target/Wasm/extend.mlir
new file mode 100644
index 0000000..5d4446a
--- /dev/null
+++ b/mlir/test/Target/Wasm/extend.mlir
@@ -0,0 +1,69 @@
+// RUN: yaml2obj %S/inputs/extend.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+ (func $i32_s (result i64)
+ i32.const 10
+ i64.extend_i32_s
+ )
+ (func $i32_u (result i64)
+ i32.const 10
+ i64.extend_i32_u
+ )
+ (func $extend8_32 (result i32)
+ i32.const 10
+ i32.extend8_s
+ )
+ (func $extend16_32 (result i32)
+ i32.const 10
+ i32.extend16_s
+ )
+ (func $extend8_64 (result i64)
+ i64.const 10
+ i64.extend8_s
+ )
+ (func $extend16_64 (result i64)
+ i64.const 10
+ i64.extend16_s
+ )
+ (func $extend32_64 (result i64)
+ i64.const 10
+ i64.extend32_s
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend_i32_s %[[VAL_0]] to i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend_i32_u %[[VAL_0]] to i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 8 : ui32 low bits from %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 16 : ui32 low bits from %[[VAL_0]] : i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_4() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 8 : ui32 low bits from %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_5() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 16 : ui32 low bits from %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func @func_6() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.extend 32 : ui32 low bits from %[[VAL_0]] : i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/global.mlir b/mlir/test/Target/Wasm/global.mlir
index e72fe69..1e4fe44 100644
--- a/mlir/test/Target/Wasm/global.mlir
+++ b/mlir/test/Target/Wasm/global.mlir
@@ -29,9 +29,9 @@ i32.add
)
*/
-// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32
+// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 : i32
-// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.global_get @global_0 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.global_get @global_1 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_1]] : i32
@@ -41,26 +41,26 @@ i32.add
// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_5]] : i32
// CHECK: wasmssa.return %[[VAL_6]] : i32
-// CHECK-LABEL: wasmssa.global @global_1 i32 nested : {
+// CHECK-LABEL: wasmssa.global @global_1 i32 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
-// CHECK-LABEL: wasmssa.global @global_2 i32 mutable nested : {
+// CHECK-LABEL: wasmssa.global @global_2 i32 mutable : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
-// CHECK-LABEL: wasmssa.global @global_3 i32 mutable nested : {
+// CHECK-LABEL: wasmssa.global @global_3 i32 mutable : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: wasmssa.return %[[VAL_0]] : i32
-// CHECK-LABEL: wasmssa.global @global_4 i64 nested : {
+// CHECK-LABEL: wasmssa.global @global_4 i64 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 11 : i64
// CHECK: wasmssa.return %[[VAL_0]] : i64
-// CHECK-LABEL: wasmssa.global @global_5 f32 nested : {
+// CHECK-LABEL: wasmssa.global @global_5 f32 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.200000e+01 : f32
// CHECK: wasmssa.return %[[VAL_0]] : f32
-// CHECK-LABEL: wasmssa.global @global_6 f64 nested : {
+// CHECK-LABEL: wasmssa.global @global_6 f64 : {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.300000e+01 : f64
// CHECK: wasmssa.return %[[VAL_0]] : f64
diff --git a/mlir/test/Target/Wasm/if.mlir b/mlir/test/Target/Wasm/if.mlir
new file mode 100644
index 0000000..2d7bfbe
--- /dev/null
+++ b/mlir/test/Target/Wasm/if.mlir
@@ -0,0 +1,112 @@
+// RUN: yaml2obj %S/inputs/if.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to create this test:
+(module
+(type $intMapper (func (param $input i32) (result i32)))
+(func $if_else (type $intMapper)
+ local.get 0
+ i32.const 1
+ i32.and
+ if $isOdd (result i32)
+ local.get 0
+ i32.const 3
+ i32.mul
+ i32.const 1
+ i32.add
+ else
+ local.get 0
+ i32.const 1
+ i32.shr_u
+ end
+)
+
+(func $if_only (type $intMapper)
+ local.get 0
+ local.get 0
+ i32.const 1
+ i32.and
+ if $isOdd (type $intMapper)
+ i32.const 1
+ i32.add
+ end
+)
+
+(func $if_if (type $intMapper)
+ local.get 0
+ i32.ctz
+ if $isEven (result i32)
+ i32.const 2
+ local.get 0
+ i32.const 1
+ i32.shr_u
+ i32.ctz
+ if $isMultipleOfFour (type $intMapper)
+ i32.const 2
+ i32.add
+ end
+ else
+ i32.const 1
+ end
+)
+)
+*/
+// CHECK-LABEL: wasmssa.func @func_0(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.and %[[VAL_0]] %[[VAL_1]] : i32
+// CHECK: wasmssa.if %[[VAL_2]] : {
+// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.const 3 : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.mul %[[VAL_3]] %[[VAL_4]] : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_7:.*]] = wasmssa.add %[[VAL_5]] %[[VAL_6]] : i32
+// CHECK: wasmssa.block_return %[[VAL_7]] : i32
+// CHECK: } "else "{
+// CHECK: %[[VAL_8:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_9:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_10:.*]] = wasmssa.shr_u %[[VAL_8]] by %[[VAL_9]] bits : i32
+// CHECK: wasmssa.block_return %[[VAL_10]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_11:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_11]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_1(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.and %[[VAL_1]] %[[VAL_2]] : i32
+// CHECK: wasmssa.if %[[VAL_3]](%[[VAL_0]]) : i32 : {
+// CHECK: ^bb0(%[[VAL_4:.*]]: i32):
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_4]] %[[VAL_5]] : i32
+// CHECK: wasmssa.block_return %[[VAL_6]] : i32
+// CHECK: } > ^bb1
+// CHECK: ^bb1(%[[VAL_7:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_7]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_2(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32
+// CHECK: wasmssa.if %[[VAL_1]] : {
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.shr_u %[[VAL_3]] by %[[VAL_4]] bits : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.ctz %[[VAL_5]] : i32
+// CHECK: wasmssa.if %[[VAL_6]](%[[VAL_2]]) : i32 : {
+// CHECK: ^bb0(%[[VAL_7:.*]]: i32):
+// CHECK: %[[VAL_8:.*]] = wasmssa.const 2 : i32
+// CHECK: %[[VAL_9:.*]] = wasmssa.add %[[VAL_7]] %[[VAL_8]] : i32
+// CHECK: wasmssa.block_return %[[VAL_9]] : i32
+// CHECK: } > ^bb1
+// CHECK: ^bb1(%[[VAL_10:.*]]: i32):
+// CHECK: wasmssa.block_return %[[VAL_10]] : i32
+// CHECK: } "else "{
+// CHECK: %[[VAL_11:.*]] = wasmssa.const 1 : i32
+// CHECK: wasmssa.block_return %[[VAL_11]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_12:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_12]] : i32
diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir
index 541dcf3..dcdfa52 100644
--- a/mlir/test/Target/Wasm/import.mlir
+++ b/mlir/test/Target/Wasm/import.mlir
@@ -11,9 +11,9 @@
)
*/
-// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()}
-// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()}
-// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
-// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"}
-// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32
-// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32
+// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {type = (i32) -> ()}
+// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {type = (i32) -> ()}
+// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {type = !wasmssa<tabletype !wasmssa.funcref [2:]>}
+// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>}
+// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 : i32
+// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32
diff --git a/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm b/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm
new file mode 100644
index 0000000..865c315
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm
@@ -0,0 +1,50 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes:
+ - I32
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: IMPORT
+ Imports:
+ - Module: env
+ Field: twoTimes
+ Kind: FUNCTION
+ SigIndex: 0
+ - Type: FUNCTION
+ FunctionTypes: [ 1 ]
+ - Type: MEMORY
+ Memories:
+ - Minimum: 0x2
+ - Type: GLOBAL
+ Globals:
+ - Index: 0
+ Type: I32
+ Mutable: true
+ InitExpr:
+ Opcode: I32_CONST
+ Value: 66560
+ - Type: EXPORT
+ Exports:
+ - Name: memory
+ Kind: MEMORY
+ Index: 0
+ - Name: add
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 1
+ Locals: []
+ Body: 20001000200110006A41026D0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/block.yaml.wasm b/mlir/test/Target/Wasm/inputs/block.yaml.wasm
new file mode 100644
index 0000000..dd5118a
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/block.yaml.wasm
@@ -0,0 +1,22 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: i_am_a_block
+ Kind: FUNCTION
+ Index: 0
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 02400B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm b/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm
new file mode 100644
index 0000000..7a125bf
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm
@@ -0,0 +1,23 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 1 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410E020041016A0B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm b/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm
new file mode 100644
index 0000000..4ba291d
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 027F41110B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm b/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm
new file mode 100644
index 0000000..40536ed
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 027F410141020D0041016A0B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/call.yaml.wasm b/mlir/test/Target/Wasm/inputs/call.yaml.wasm
new file mode 100644
index 0000000..535a623
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/call.yaml.wasm
@@ -0,0 +1,26 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: forty_two
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 412A0B
+ - Index: 1
+ Locals: []
+ Body: 10000B
+...
diff --git a/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm b/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm
new file mode 100644
index 0000000..cde9ee1
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm
@@ -0,0 +1,88 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C4132480B
+ - Index: 1
+ Locals: []
+ Body: 410C41324C0B
+ - Index: 2
+ Locals: []
+ Body: 410C4132490B
+ - Index: 3
+ Locals: []
+ Body: 410C41324D0B
+ - Index: 4
+ Locals: []
+ Body: 410C41324A0B
+ - Index: 5
+ Locals: []
+ Body: 410C41324B0B
+ - Index: 6
+ Locals: []
+ Body: 410C41324E0B
+ - Index: 7
+ Locals: []
+ Body: 410C41324F0B
+ - Index: 8
+ Locals: []
+ Body: 420C4232530B
+ - Index: 9
+ Locals: []
+ Body: 420C4232570B
+ - Index: 10
+ Locals: []
+ Body: 420C4232540B
+ - Index: 11
+ Locals: []
+ Body: 420C4232580B
+ - Index: 12
+ Locals: []
+ Body: 420C4232550B
+ - Index: 13
+ Locals: []
+ Body: 420C4232560B
+ - Index: 14
+ Locals: []
+ Body: 420C4232590B
+ - Index: 15
+ Locals: []
+ Body: 420C42325A0B
+ - Index: 16
+ Locals: []
+ Body: 430000A04043000060415D0B
+ - Index: 17
+ Locals: []
+ Body: 430000A04043000060415F0B
+ - Index: 18
+ Locals: []
+ Body: 430000A04043000060415E0B
+ - Index: 19
+ Locals: []
+ Body: 430000A0404300006041600B
+ - Index: 20
+ Locals: []
+ Body: 440000000000001440440000000000002C40630B
+ - Index: 21
+ Locals: []
+ Body: 440000000000001440440000000000002C40650B
+ - Index: 22
+ Locals: []
+ Body: 440000000000001440440000000000002C40640B
+ - Index: 23
+ Locals: []
+ Body: 440000000000001440440000000000002C40660B
+...
diff --git a/mlir/test/Target/Wasm/inputs/convert.yaml.wasm b/mlir/test/Target/Wasm/inputs/convert.yaml.wasm
new file mode 100644
index 0000000..c346a75
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/convert.yaml.wasm
@@ -0,0 +1,69 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0, 1, 1, 1, 1 ]
+ - Type: EXPORT
+ Exports:
+ - Name: convert_i32_u_to_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: convert_i32_s_to_f32
+ Kind: FUNCTION
+ Index: 1
+ - Name: convert_i64_u_to_f32
+ Kind: FUNCTION
+ Index: 2
+ - Name: convert_i64s_to_f32
+ Kind: FUNCTION
+ Index: 3
+ - Name: convert_i32_u_to_f64
+ Kind: FUNCTION
+ Index: 4
+ - Name: convert_i32_s_to_f64
+ Kind: FUNCTION
+ Index: 5
+ - Name: convert_i64_u_to_f64
+ Kind: FUNCTION
+ Index: 6
+ - Name: convert_i64s_to_f64
+ Kind: FUNCTION
+ Index: 7
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410AB30B
+ - Index: 1
+ Locals: []
+ Body: 412AB20B
+ - Index: 2
+ Locals: []
+ Body: 4211B50B
+ - Index: 3
+ Locals: []
+ Body: 420AB40B
+ - Index: 4
+ Locals: []
+ Body: 410AB80B
+ - Index: 5
+ Locals: []
+ Body: 412AB70B
+ - Index: 6
+ Locals: []
+ Body: 4211BA0B
+ - Index: 7
+ Locals: []
+ Body: 420AB90B
+...
diff --git a/mlir/test/Target/Wasm/inputs/demote.yaml.wasm b/mlir/test/Target/Wasm/inputs/demote.yaml.wasm
new file mode 100644
index 0000000..3997045
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/demote.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 44EC51B81E85EB0140B60B
+...
diff --git a/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm b/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm
new file mode 100644
index 0000000..41a2944
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm
@@ -0,0 +1,19 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals:
+ - Type: I32
+ Count: 2
+ Body: 0340200041016A2100037F41012001410C6A220120004A0D000B410A480D000B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm b/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm
new file mode 100644
index 0000000..3171409
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm
@@ -0,0 +1,21 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 0240024002400B0B0B0B
+ - Index: 1
+ Locals: []
+ Body: 02400B02400B02400B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/eq.yaml.wasm b/mlir/test/Target/Wasm/inputs/eq.yaml.wasm
new file mode 100644
index 0000000..1998369
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/eq.yaml.wasm
@@ -0,0 +1,27 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C4132460B
+ - Index: 1
+ Locals: []
+ Body: 42144205510B
+ - Index: 2
+ Locals: []
+ Body: 430000A04043000060415B0B
+ - Index: 3
+ Locals: []
+ Body: 440000000000003140440000000000000000610B
+...
diff --git a/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm b/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm
new file mode 100644
index 0000000..894ac50
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm
@@ -0,0 +1,29 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: eqz_i32
+ Kind: FUNCTION
+ Index: 0
+ - Name: eqz_i64
+ Kind: FUNCTION
+ Index: 1
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410D450B
+ - Index: 1
+ Locals: []
+ Body: 420D500B
+...
diff --git a/mlir/test/Target/Wasm/inputs/extend.yaml.wasm b/mlir/test/Target/Wasm/inputs/extend.yaml.wasm
new file mode 100644
index 0000000..7e872ba
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/extend.yaml.wasm
@@ -0,0 +1,40 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 1, 1, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410AAC0B
+ - Index: 1
+ Locals: []
+ Body: 410AAD0B
+ - Index: 2
+ Locals: []
+ Body: 410AC00B
+ - Index: 3
+ Locals: []
+ Body: 410AC10B
+ - Index: 4
+ Locals: []
+ Body: 420AC20B
+ - Index: 5
+ Locals: []
+ Body: 420AC30B
+ - Index: 6
+ Locals: []
+ Body: 420AC40B
+...
diff --git a/mlir/test/Target/Wasm/inputs/if.yaml.wasm b/mlir/test/Target/Wasm/inputs/if.yaml.wasm
new file mode 100644
index 0000000..ccc38f6
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/if.yaml.wasm
@@ -0,0 +1,25 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 2000410171047F200041036C41016A0520004101760B0B
+ - Index: 1
+ Locals: []
+ Body: 20002000410171040041016A0B0B
+ - Index: 2
+ Locals: []
+ Body: 200068047F4102200041017668040041026A0B0541010B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/loop.yaml.wasm b/mlir/test/Target/Wasm/inputs/loop.yaml.wasm
new file mode 100644
index 0000000..9d33894
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/loop.yaml.wasm
@@ -0,0 +1,17 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes: []
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 03400B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm b/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm
new file mode 100644
index 0000000..4b8cc54
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm
@@ -0,0 +1,20 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals:
+ - Type: I32
+ Count: 1
+ Body: 037F200041016A21002000410A480B0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/ne.yaml.wasm b/mlir/test/Target/Wasm/inputs/ne.yaml.wasm
new file mode 100644
index 0000000..0167519
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/ne.yaml.wasm
@@ -0,0 +1,27 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 0, 0, 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410C4132470B
+ - Index: 1
+ Locals: []
+ Body: 42144205520B
+ - Index: 2
+ Locals: []
+ Body: 430000A04043000060415C0B
+ - Index: 3
+ Locals: []
+ Body: 440000000000003140440000000000000000620B
+...
diff --git a/mlir/test/Target/Wasm/inputs/promote.yaml.wasm b/mlir/test/Target/Wasm/inputs/promote.yaml.wasm
new file mode 100644
index 0000000..d38603e
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/promote.yaml.wasm
@@ -0,0 +1,18 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4300002841BB0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm b/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm
new file mode 100644
index 0000000..c01c1b1
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm
@@ -0,0 +1,53 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I64
+ - Index: 2
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Index: 3
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 2, 3 ]
+ - Type: EXPORT
+ Exports:
+ - Name: i32.reinterpret_f32
+ Kind: FUNCTION
+ Index: 0
+ - Name: i64.reinterpret_f64
+ Kind: FUNCTION
+ Index: 1
+ - Name: f32.reinterpret_i32
+ Kind: FUNCTION
+ Index: 2
+ - Name: f64.reinterpret_i64
+ Kind: FUNCTION
+ Index: 3
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 43000080BFBC0B
+ - Index: 1
+ Locals: []
+ Body: 44000000000000F0BFBD0B
+ - Index: 2
+ Locals: []
+ Body: 417FBE0B
+ - Index: 3
+ Locals: []
+ Body: 427FBF0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm b/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm
new file mode 100644
index 0000000..c6e8bf6
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm
@@ -0,0 +1,37 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes: []
+ ReturnTypes:
+ - F64
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - F32
+ - Type: FUNCTION
+ FunctionTypes: [ 0, 1, 0, 1, 0, 1 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 4433333333333328C09B0B
+ - Index: 1
+ Locals: []
+ Body: 43A01ACF3F8D0B
+ - Index: 2
+ Locals: []
+ Body: 4433333333333328C09C0B
+ - Index: 3
+ Locals: []
+ Body: 43A01ACF3F8E0B
+ - Index: 4
+ Locals: []
+ Body: 4433333333333328C09D0B
+ - Index: 5
+ Locals: []
+ Body: 43A01ACF3F8F0B
+...
diff --git a/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm b/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm
new file mode 100644
index 0000000..51c0b02
--- /dev/null
+++ b/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm
@@ -0,0 +1,24 @@
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I64
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 0 ]
+ - Type: EXPORT
+ Exports:
+ - Name: i64_wrap
+ Kind: FUNCTION
+ Index: 0
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 2000A70B
+...
diff --git a/mlir/test/Target/Wasm/invalid_block_type_index.yaml b/mlir/test/Target/Wasm/invalid_block_type_index.yaml
new file mode 100644
index 0000000..5b83e2e
--- /dev/null
+++ b/mlir/test/Target/Wasm/invalid_block_type_index.yaml
@@ -0,0 +1,28 @@
+
+# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s
+
+# CHECK: type index references nonexistent type (2)
+
+--- !WASM
+FileHeader:
+ Version: 0x1
+Sections:
+ - Type: TYPE
+ Signatures:
+ - Index: 0
+ ParamTypes:
+ - I32
+ ReturnTypes:
+ - I32
+ - Index: 1
+ ParamTypes: []
+ ReturnTypes:
+ - I32
+ - Type: FUNCTION
+ FunctionTypes: [ 1 ]
+ - Type: CODE
+ Functions:
+ - Index: 0
+ Locals: []
+ Body: 410E020241016A0B0B
+# -----------------------------^^ Invalid type ID
diff --git a/mlir/test/Target/Wasm/local.mlir b/mlir/test/Target/Wasm/local.mlir
index 32f5900..9844f9c 100644
--- a/mlir/test/Target/Wasm/local.mlir
+++ b/mlir/test/Target/Wasm/local.mlir
@@ -29,7 +29,7 @@
)
*/
-// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32
@@ -40,7 +40,7 @@
// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : f32
// CHECK: wasmssa.return %[[VAL_6]] : f32
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32
// CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32
@@ -51,7 +51,7 @@
// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : i32
// CHECK: wasmssa.return %[[VAL_6]] : i32
-// CHECK-LABEL: wasmssa.func nested @func_2(
+// CHECK-LABEL: wasmssa.func @func_2(
// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32
// CHECK: wasmssa.local_set %[[ARG0]] : ref to i32 to %[[VAL_0]] : i32
diff --git a/mlir/test/Target/Wasm/loop.mlir b/mlir/test/Target/Wasm/loop.mlir
new file mode 100644
index 0000000..29ad502
--- /dev/null
+++ b/mlir/test/Target/Wasm/loop.mlir
@@ -0,0 +1,17 @@
+// RUN: yaml2obj %S/inputs/loop.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* IR generated from:
+(module
+ (func
+ (loop $my_loop
+ )
+ )
+)*/
+
+// CHECK-LABEL: wasmssa.func @func_0() {
+// CHECK: wasmssa.loop : {
+// CHECK: wasmssa.block_return
+// CHECK: }> ^bb1
+// CHECK: ^bb1:
+// CHECK: wasmssa.return
+// CHECK: }
diff --git a/mlir/test/Target/Wasm/loop_with_inst.mlir b/mlir/test/Target/Wasm/loop_with_inst.mlir
new file mode 100644
index 0000000..311d007
--- /dev/null
+++ b/mlir/test/Target/Wasm/loop_with_inst.mlir
@@ -0,0 +1,33 @@
+// RUN: yaml2obj %S/inputs/loop_with_inst.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Code used to create this test:
+
+(module
+ (func (result i32)
+ (local $i i32)
+ (loop $my_loop (result i32)
+ local.get $i
+ i32.const 1
+ i32.add
+ local.set $i
+ local.get $i
+ i32.const 10
+ i32.lt_s
+ )
+ )
+)*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32
+// CHECK: wasmssa.loop : {
+// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32
+// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_2]] : i32
+// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_3]] : i32
+// CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32
+// CHECK: %[[VAL_5:.*]] = wasmssa.const 10 : i32
+// CHECK: %[[VAL_6:.*]] = wasmssa.lt_si %[[VAL_4]] %[[VAL_5]] : i32 -> i32
+// CHECK: wasmssa.block_return %[[VAL_6]] : i32
+// CHECK: }> ^bb1
+// CHECK: ^bb1(%[[VAL_7:.*]]: i32):
+// CHECK: wasmssa.return %[[VAL_7]] : i32
diff --git a/mlir/test/Target/Wasm/max.mlir b/mlir/test/Target/Wasm/max.mlir
index 4ef2042..9160bde 100644
--- a/mlir/test/Target/Wasm/max.mlir
+++ b/mlir/test/Target/Wasm/max.mlir
@@ -16,14 +16,14 @@
)
*/
-// CHECK-LABEL: wasmssa.func @min_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @min_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func @min_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @min_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
index 2ba5ab5..ea8f719 100644
--- a/mlir/test/Target/Wasm/memory_min_eq_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir
@@ -4,4 +4,4 @@
(module (memory 0 0))
*/
-// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 0]>
+// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[0: 0]>
diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir
index ebf6418..88782ec 100644
--- a/mlir/test/Target/Wasm/memory_min_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_max.mlir
@@ -4,4 +4,4 @@
(module (memory 0 65536))
*/
-// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 65536]>
+// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[0: 65536]>
diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir
index 8d88786..c10c5cc 100644
--- a/mlir/test/Target/Wasm/memory_min_no_max.mlir
+++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir
@@ -4,4 +4,4 @@
(module (memory 1))
*/
-// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[1:]>
+// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[1:]>
diff --git a/mlir/test/Target/Wasm/min.mlir b/mlir/test/Target/Wasm/min.mlir
index 1058c7d..2372bcc 100644
--- a/mlir/test/Target/Wasm/min.mlir
+++ b/mlir/test/Target/Wasm/min.mlir
@@ -16,13 +16,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @min_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @min_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func @min_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @min_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/ne.mlir b/mlir/test/Target/Wasm/ne.mlir
new file mode 100644
index 0000000..331df75
--- /dev/null
+++ b/mlir/test/Target/Wasm/ne.mlir
@@ -0,0 +1,52 @@
+// RUN: yaml2obj %S/inputs/ne.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $ne_i32 (result i32)
+ i32.const 12
+ i32.const 50
+ i32.ne
+ )
+
+ (func $ne_i64 (result i32)
+ i64.const 20
+ i64.const 5
+ i64.ne
+ )
+
+ (func $ne_f32 (result i32)
+ f32.const 5
+ f32.const 14
+ f32.ne
+ )
+
+ (func $ne_f64 (result i32)
+ f64.const 17
+ f64.const 0
+ f64.ne
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : i32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_1() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : i64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_2() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : f32 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
+
+// CHECK-LABEL: wasmssa.func @func_3() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
+// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : f64 -> i32
+// CHECK: wasmssa.return %[[VAL_2]] : i32
diff --git a/mlir/test/Target/Wasm/neg.mlir b/mlir/test/Target/Wasm/neg.mlir
index 5811ab50..dae8ee5 100644
--- a/mlir/test/Target/Wasm/neg.mlir
+++ b/mlir/test/Target/Wasm/neg.mlir
@@ -12,12 +12,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @neg_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @neg_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f32
// CHECK: wasmssa.return %[[VAL_1]] : f32
-// CHECK-LABEL: wasmssa.func @neg_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @neg_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f64
// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/or.mlir b/mlir/test/Target/Wasm/or.mlir
index 521f2ba..be0b3d7 100644
--- a/mlir/test/Target/Wasm/or.mlir
+++ b/mlir/test/Target/Wasm/or.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @or_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @or_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.or %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @or_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @or_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.or %0 %1 : i64
diff --git a/mlir/test/Target/Wasm/popcnt.mlir b/mlir/test/Target/Wasm/popcnt.mlir
index 235333a..bfaa8eb 100644
--- a/mlir/test/Target/Wasm/popcnt.mlir
+++ b/mlir/test/Target/Wasm/popcnt.mlir
@@ -14,12 +14,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @popcnt_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @popcnt_i32() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i32
// CHECK: wasmssa.return %[[VAL_1]] : i32
-// CHECK-LABEL: wasmssa.func @popcnt_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @popcnt_i64() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i64
// CHECK: wasmssa.return %[[VAL_1]] : i64
diff --git a/mlir/test/Target/Wasm/promote.mlir b/mlir/test/Target/Wasm/promote.mlir
new file mode 100644
index 0000000..44c31b6
--- /dev/null
+++ b/mlir/test/Target/Wasm/promote.mlir
@@ -0,0 +1,14 @@
+// RUN: yaml2obj %S/inputs/promote.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/* Source code used to generate this test:
+(module
+ (func $main (result f64)
+ f32.const 10.5
+ f64.promote_f32
+ )
+)*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.050000e+01 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.promote %[[VAL_0]] : f32 to f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/reinterpret.mlir b/mlir/test/Target/Wasm/reinterpret.mlir
new file mode 100644
index 0000000..574d13f
--- /dev/null
+++ b/mlir/test/Target/Wasm/reinterpret.mlir
@@ -0,0 +1,46 @@
+// RUN: yaml2obj %S/inputs/reinterpret.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+
+/*
+Test generated from:
+(module
+ (func (export "i32.reinterpret_f32") (result i32)
+ f32.const -1
+ i32.reinterpret_f32
+ )
+
+ (func (export "i64.reinterpret_f64") (result i64)
+ f64.const -1
+ i64.reinterpret_f64
+ )
+
+ (func (export "f32.reinterpret_i32") (result f32)
+ i32.const -1
+ f32.reinterpret_i32
+ )
+
+ (func (export "f64.reinterpret_i64") (result f64)
+ i64.const -1
+ f64.reinterpret_i64
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @i32.reinterpret_f32() -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.000000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : f32 as i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
+
+// CHECK-LABEL: wasmssa.func exported @i64.reinterpret_f64() -> i64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.000000e+00 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : f64 as i64
+// CHECK: wasmssa.return %[[VAL_1]] : i64
+
+// CHECK-LABEL: wasmssa.func exported @f32.reinterpret_i32() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1 : i32
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : i32 as f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func exported @f64.reinterpret_i64() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1 : i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : i64 as f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/rem.mlir b/mlir/test/Target/Wasm/rem.mlir
index b19b8d9..16c9c78 100644
--- a/mlir/test/Target/Wasm/rem.mlir
+++ b/mlir/test/Target/Wasm/rem.mlir
@@ -24,28 +24,28 @@
)
*/
-// CHECK-LABEL: wasmssa.func @rem_u_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rem_u_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rem_ui %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
// CHECK: }
-// CHECK-LABEL: wasmssa.func @rem_u_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rem_u_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rem_ui %0 %1 : i64
// CHECK: wasmssa.return %2 : i64
// CHECK: }
-// CHECK-LABEL: wasmssa.func @rem_s_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rem_s_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rem_si %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
// CHECK: }
-// CHECK-LABEL: wasmssa.func @rem_s_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rem_s_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rem_si %0 %1 : i64
diff --git a/mlir/test/Target/Wasm/rotl.mlir b/mlir/test/Target/Wasm/rotl.mlir
index ec573554..4c2e5af 100644
--- a/mlir/test/Target/Wasm/rotl.mlir
+++ b/mlir/test/Target/Wasm/rotl.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @rotl_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rotl_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rotl %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @rotl_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rotl_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rotl %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/rotr.mlir b/mlir/test/Target/Wasm/rotr.mlir
index 5618b43..ec403d0 100644
--- a/mlir/test/Target/Wasm/rotr.mlir
+++ b/mlir/test/Target/Wasm/rotr.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @rotr_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @rotr_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.rotr %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @rotr_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @rotr_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.rotr %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/rounding.mlir b/mlir/test/Target/Wasm/rounding.mlir
new file mode 100644
index 0000000..947637e
--- /dev/null
+++ b/mlir/test/Target/Wasm/rounding.mlir
@@ -0,0 +1,50 @@
+// RUN: yaml2obj %S/inputs/rounding.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func $ceil_f64 (result f64)
+ f64.const -12.1
+ f64.ceil
+ )
+ (func $ceil_f32 (result f32)
+ f32.const 1.618
+ f32.ceil
+ )
+ (func $floor_f64 (result f64)
+ f64.const -12.1
+ f64.floor
+ )
+ (func $floor_f32 (result f32)
+ f32.const 1.618
+ f32.floor
+ )
+*/
+
+// CHECK-LABEL: wasmssa.func @func_0() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.ceil %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func @func_1() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.ceil %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @func_2() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.floor %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func @func_3() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.floor %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
+
+// CHECK-LABEL: wasmssa.func @func_4() -> f64 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64
+// CHECK: %[[VAL_1:.*]] = wasmssa.trunc %[[VAL_0]] : f64
+// CHECK: wasmssa.return %[[VAL_1]] : f64
+
+// CHECK-LABEL: wasmssa.func @func_5() -> f32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32
+// CHECK: %[[VAL_1:.*]] = wasmssa.trunc %[[VAL_0]] : f32
+// CHECK: wasmssa.return %[[VAL_1]] : f32
diff --git a/mlir/test/Target/Wasm/shl.mlir b/mlir/test/Target/Wasm/shl.mlir
index f2bdd57..1363112 100644
--- a/mlir/test/Target/Wasm/shl.mlir
+++ b/mlir/test/Target/Wasm/shl.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @shl_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @shl_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.shl %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @shl_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @shl_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.shl %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/shr_s.mlir b/mlir/test/Target/Wasm/shr_s.mlir
index 247d9be..da1a38f 100644
--- a/mlir/test/Target/Wasm/shr_s.mlir
+++ b/mlir/test/Target/Wasm/shr_s.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @shr_s_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @shr_s_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @shr_s_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @shr_s_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/shr_u.mlir b/mlir/test/Target/Wasm/shr_u.mlir
index 9a79eed..2991c2a 100644
--- a/mlir/test/Target/Wasm/shr_u.mlir
+++ b/mlir/test/Target/Wasm/shr_u.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @shr_u_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @shr_u_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @shr_u_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @shr_u_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i64
diff --git a/mlir/test/Target/Wasm/sqrt.mlir b/mlir/test/Target/Wasm/sqrt.mlir
index 77444ad..6b968d6 100644
--- a/mlir/test/Target/Wasm/sqrt.mlir
+++ b/mlir/test/Target/Wasm/sqrt.mlir
@@ -12,12 +12,12 @@
)
*/
-// CHECK-LABEL: wasmssa.func @sqrt_f32() -> f32 {
+// CHECK-LABEL: wasmssa.func exported @sqrt_f32() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f32
// CHECK: wasmssa.return %[[VAL_1]] : f32
-// CHECK-LABEL: wasmssa.func @sqrt_f64() -> f64 {
+// CHECK-LABEL: wasmssa.func exported @sqrt_f64() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f64
// CHECK: wasmssa.return %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/sub.mlir b/mlir/test/Target/Wasm/sub.mlir
index b9c6caf..5b242f4 100644
--- a/mlir/test/Target/Wasm/sub.mlir
+++ b/mlir/test/Target/Wasm/sub.mlir
@@ -27,25 +27,25 @@
)
*/
-// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 {
+// CHECK-LABEL: wasmssa.func @func_0() -> i32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i32
// CHECK: wasmssa.return %[[VAL_2]] : i32
-// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 {
+// CHECK-LABEL: wasmssa.func @func_1() -> i64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i64
// CHECK: wasmssa.return %[[VAL_2]] : i64
-// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 {
+// CHECK-LABEL: wasmssa.func @func_2() -> f32 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32
// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f32
// CHECK: wasmssa.return %[[VAL_2]] : f32
-// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 {
+// CHECK-LABEL: wasmssa.func @func_3() -> f64 {
// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64
// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64
// CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f64
diff --git a/mlir/test/Target/Wasm/wrap.mlir b/mlir/test/Target/Wasm/wrap.mlir
new file mode 100644
index 0000000..1266758
--- /dev/null
+++ b/mlir/test/Target/Wasm/wrap.mlir
@@ -0,0 +1,15 @@
+// RUN: yaml2obj %S/inputs/wrap.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s
+/* Source code used to create this test:
+(module
+ (func (export "i64_wrap") (param $in i64) (result i32)
+ local.get $in
+ i32.wrap_i64
+ )
+)
+*/
+
+// CHECK-LABEL: wasmssa.func exported @i64_wrap(
+// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i64>) -> i32 {
+// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i64
+// CHECK: %[[VAL_1:.*]] = wasmssa.wrap %[[VAL_0]] : i64 to i32
+// CHECK: wasmssa.return %[[VAL_1]] : i32
diff --git a/mlir/test/Target/Wasm/xor.mlir b/mlir/test/Target/Wasm/xor.mlir
index 94691de..56407db 100644
--- a/mlir/test/Target/Wasm/xor.mlir
+++ b/mlir/test/Target/Wasm/xor.mlir
@@ -14,13 +14,13 @@
)
*/
-// CHECK-LABEL: wasmssa.func @xor_i32() -> i32 {
+// CHECK-LABEL: wasmssa.func exported @xor_i32() -> i32 {
// CHECK: %0 = wasmssa.const 10 : i32
// CHECK: %1 = wasmssa.const 3 : i32
// CHECK: %2 = wasmssa.xor %0 %1 : i32
// CHECK: wasmssa.return %2 : i32
-// CHECK-LABEL: wasmssa.func @xor_i64() -> i64 {
+// CHECK-LABEL: wasmssa.func exported @xor_i64() -> i64 {
// CHECK: %0 = wasmssa.const 10 : i64
// CHECK: %1 = wasmssa.const 3 : i64
// CHECK: %2 = wasmssa.xor %0 %1 : i64
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 9187998..c37671a 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(MLIRTestAnalysis
DataFlow/TestDenseForwardDataFlowAnalysis.cpp
DataFlow/TestLivenessAnalysis.cpp
DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+ DataFlow/TestStridedMetadataRangeAnalysis.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000..6ac09fd
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,86 @@
+//===- TestStridedMetadataRangeAnalysis.cpp - Test strided md analysis ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
+ raw_ostream &os) {
+ // Collect the strided metadata of the op results.
+ SmallVector<std::pair<unsigned, const StridedMetadataRangeLattice *>> results;
+ for (OpResult result : op->getResults()) {
+ const auto *state = solver.lookupState<StridedMetadataRangeLattice>(result);
+ // Skip the result if it's uninitialized.
+ if (!state || state->getValue().isUninitialized())
+ continue;
+
+ // Skip the result if the range is empty.
+ const mlir::StridedMetadataRange &md = state->getValue();
+ if (md.getOffsets().empty() && md.getSizes().empty() &&
+ md.getStrides().empty())
+ continue;
+ results.push_back({result.getResultNumber(), state});
+ }
+
+ // Early exit if there's no metadata to print.
+ if (results.empty())
+ return;
+
+ // Print the metadata.
+ os << "Op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
+ for (auto [idx, state] : results)
+ os << " result[" << idx << "]: " << state->getValue() << "\n";
+ os << "\n";
+}
+
+namespace {
+struct TestStridedMetadataRangeAnalysisPass
+ : public PassWrapper<TestStridedMetadataRangeAnalysisPass,
+ OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestStridedMetadataRangeAnalysisPass)
+
+ StringRef getArgument() const override {
+ return "test-strided-metadata-range-analysis";
+ }
+ void runOnOperation() override {
+ Operation *op = getOperation();
+
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ solver.load<IntegerRangeAnalysis>();
+ solver.load<StridedMetadataRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+
+ op->walk(
+ [&](Operation *op) { printAnalysisResults(solver, op, llvm::errs()); });
+ }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestStridedMetadataRangeAnalysisPass() {
+ PassRegistration<TestStridedMetadataRangeAnalysisPass>();
+}
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
index 1e2d4a7..4069a74 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -11,11 +11,25 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
+#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr
+#include "TestDialect.h"
+
using namespace mlir;
namespace {
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+ auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
+ tensorType.getEncoding());
+ if (!encoding)
+ return {}; // No test-dialect encoding on the tensor: no custom layout.
+ return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+ tensorType.getContext(), encoding.getDummy()));
+}
+
struct TestOneShotModuleBufferizePass
: public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
@@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<test::TestDialect>();
registry.insert<bufferization::BufferizationDialect>();
}
StringRef getArgument() const final {
@@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass
bufferization::OneShotBufferizationOptions opt;
opt.bufferizeFunctionBoundaries = true;
+ opt.functionArgTypeConverterFn =
+ [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+ func::FuncOp, const bufferization::BufferizationOptions &) {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ auto layout = getMemRefLayoutForTensorEncoding(tensorType);
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(tensorType.getShape(),
+ tensorType.getElementType(), layout, memSpace));
+ };
+
bufferization::BufferizationState bufferizationState;
if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 727c84c..8c5c8e8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -276,10 +276,8 @@ void TestLinalgTransforms::runOnOperation() {
Operation *consumer = opOperand->getOwner();
// If we have a pack/unpack consumer and a producer that has multiple
// uses, do not apply the folding patterns.
- if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
- isa<TilingInterface>(producer) && !producer->hasOneUse())
- return false;
- return true;
+ return !(isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
+ isa<TilingInterface>(producer) && !producer->hasOneUse());
};
applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
}
diff --git a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
index f84055d..1e59338 100644
--- a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_library(MLIROpenACCTestPasses
TestOpenACC.cpp
TestPointerLikeTypeInterface.cpp
+ TestRecipePopulate.cpp
EXCLUDE_FROM_LIBMLIR
)
diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
index 9886240..bea21b9 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
@@ -15,9 +15,13 @@ namespace test {
// Forward declarations of individual test pass registration functions
void registerTestPointerLikeTypeInterfacePass();
+void registerTestRecipePopulatePass();
// Unified registration function for all OpenACC tests
-void registerTestOpenACC() { registerTestPointerLikeTypeInterfacePass(); }
+void registerTestOpenACC() {
+ registerTestPointerLikeTypeInterfacePass();
+ registerTestRecipePopulatePass();
+}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
index 85f9283..027b0a1 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
@@ -196,13 +196,15 @@ void TestPointerLikeTypeInterfacePass::testGenAllocate(
newBuilder.setInsertionPointAfter(op);
// Call the genAllocate API
+ bool needsFree = false;
Value allocRes = pointerType.genAllocate(newBuilder, loc, "test_alloc",
- result.getType(), result);
+ result.getType(), result, needsFree);
if (allocRes) {
llvm::errs() << "Successfully generated alloc for operation: ";
op->print(llvm::errs());
llvm::errs() << "\n";
+ llvm::errs() << "\tneeds free: " << (needsFree ? "true" : "false") << "\n";
// Print all operations that were inserted
for (Operation *insertedOp : tracker.insertedOps) {
@@ -230,8 +232,8 @@ void TestPointerLikeTypeInterfacePass::testGenFree(Operation *op, Value result,
// Call the genFree API
auto typedResult = cast<TypedValue<PointerLikeType>>(result);
- bool success =
- pointerType.genFree(newBuilder, loc, typedResult, result.getType());
+ bool success = pointerType.genFree(newBuilder, loc, typedResult, result,
+ result.getType());
if (success) {
llvm::errs() << "Successfully generated free for operation: ";
diff --git a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp
new file mode 100644
index 0000000..35f092c
--- /dev/null
+++ b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp
@@ -0,0 +1,110 @@
+//===- TestRecipePopulate.cpp - Test Recipe Population -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for testing the createAndPopulate methods
+// of the recipe operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+struct TestRecipePopulatePass
+ : public PassWrapper<TestRecipePopulatePass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRecipePopulatePass)
+
+ TestRecipePopulatePass() = default;
+ TestRecipePopulatePass(const TestRecipePopulatePass &pass)
+ : PassWrapper(pass) {
+ recipeType = pass.recipeType;
+ }
+
+ Pass::Option<std::string> recipeType{
+ *this, "recipe-type",
+ llvm::cl::desc("Recipe type: private or firstprivate"),
+ llvm::cl::init("private")};
+
+ StringRef getArgument() const override { return "test-acc-recipe-populate"; }
+
+ StringRef getDescription() const override {
+ return "Test OpenACC recipe population";
+ }
+
+ void runOnOperation() override;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<acc::OpenACCDialect>();
+ registry.insert<arith::ArithDialect>();
+ registry.insert<memref::MemRefDialect>();
+ }
+};
+
+/// Builds one recipe of the selected kind per result of each `test.var` op.
+void TestRecipePopulatePass::runOnOperation() {
+ auto module = getOperation();
+ OpBuilder builder(&getContext());
+ const std::string &kind = recipeType.getValue();
+ // Reject unknown recipe kinds instead of silently generating nothing.
+ if (kind != "private" && kind != "firstprivate") {
+ module.emitError("Unknown recipe type: ") << kind;
+ return signalPassFailure();
+ }
+
+ // Collect every result of each op tagged with a `test.var` attribute.
+ SmallVector<std::tuple<Operation *, Value, std::string>> testVars;
+ module.walk([&](Operation *op) {
+ if (auto varName = op->getAttrOfType<StringAttr>("test.var"))
+ for (auto result : op->getResults())
+ testVars.push_back({op, result, varName.str()});
+ });
+
+ // Generate recipes at module level, ahead of the existing module contents.
+ builder.setInsertionPointToStart(module.getBody());
+
+ for (auto [op, var, varName] : testVars) {
+ Location loc = op->getLoc();
+ std::string recipeName = kind + "_" + varName;
+ ValueRange bounds; // No bounds for memref tests
+
+ // Either factory yields a falsy result when the recipe cannot be built.
+ bool created;
+ if (kind == "private")
+ created = static_cast<bool>(PrivateRecipeOp::createAndPopulate(
+ builder, loc, recipeName, var.getType(), varName, bounds));
+ else
+ created = static_cast<bool>(FirstprivateRecipeOp::createAndPopulate(
+ builder, loc, recipeName, var.getType(), varName, bounds));
+ if (created)
+ continue;
+ // Report on the tagged op and fail the pass so the error is not lost.
+ op->emitError("Failed to create ") << kind << " recipe for " << varName;
+ signalPassFailure();
+ }
+}
+
+} // namespace
+
+namespace mlir {
+namespace test {
+
+void registerTestRecipePopulatePass() {
+ PassRegistration<TestRecipePopulatePass>();
+}
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 5685004..9e7e4f8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/TensorEncoding.td"
// All of the attributes will extend this class.
class Test_Attr<string name, list<Trait> traits = []>
@@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> {
let hasStorageCustomConstructor = 1;
}
+def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding",
+ [DeclareAttrInterfaceMethods<VerifiableTensorEncoding>]> {
+ let mnemonic = "tensor_encoding";
+
+ let parameters = (ins "mlir::StringAttr":$dummy);
+ let assemblyFormat = "`<` $dummy `>`";
+}
+
+def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
+ [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
+ let mnemonic = "memref_layout";
+
+ let parameters = (ins "mlir::StringAttr":$dummy);
+ let assemblyFormat = "`<` $dummy `>`";
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index fe1e916..9db7b01 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -542,6 +542,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct(
}
//===----------------------------------------------------------------------===//
+// TestTensorEncodingAttr
+//===----------------------------------------------------------------------===//
+
+::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding(
+ mlir::ArrayRef<int64_t> shape, mlir::Type elementType,
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const {
+ return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestMemRefLayoutAttr
+//===----------------------------------------------------------------------===//
+
+mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const {
+ return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+}
+
+//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 778d84fa..0ad5ab6 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -24,6 +24,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/TensorEncoding.h"
// generated files require above includes to come first
#include "TestAttrInterfaces.h.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index f2adca6..bcf3b55d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -18,6 +18,7 @@
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 2b5491f..37a263f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -24,7 +24,10 @@ def Test_Dialect : Dialect {
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
- let dependentDialects = ["::mlir::DLTIDialect"];
+ let dependentDialects = [
+ "::mlir::DLTIDialect",
+ "::mlir::bufferization::BufferizationDialect"
+ ];
let discardableAttrs = (ins
"mlir::IntegerAttr":$discardable_attr_key,
"SimpleAAttr":$other_discardable_attr_key
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 53055fe..b211e24 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
+namespace {
+/// Returns the test dialect's memref layout for the test dialect's tensor
+/// encoding on `tensorType`, or a null layout when no such encoding is set.
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+ // The encoding may be absent (null attribute), so the cast must accept null.
+ if (auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
+ tensorType.getEncoding()))
+ return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+ tensorType.getContext(), encoding.getDummy()));
+ return {};
+}
+
+/// Bufferizes a test or builtin tensor; getBufferType() failure is unchecked.
+bufferization::BufferLikeType
+convertTensorToBuffer(mlir::Operation *op,
+ const bufferization::BufferizationOptions &options,
+ bufferization::TensorLikeType tensorLike) {
+ auto buffer =
+ *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
+ if (auto memref = dyn_cast<MemRefType>(buffer)) {
+ // Note: For the sake of testing, force the encoding -> layout conversion by
+ // rebuilding the memref with the layout derived from the tensor's encoding.
+ auto layout =
+ getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(memref.getShape(), memref.getElementType(), layout,
+ memref.getMemorySpace()));
+ }
+ return buffer;
+}
+} // namespace
+
::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
@@ -1435,8 +1468,8 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return failure();
const auto outType = getOutput().getType();
- const auto bufferizedOutType = test::TestMemrefType::get(
- getContext(), outType.getShape(), outType.getElementType(), nullptr);
+ const auto bufferizedOutType =
+ convertTensorToBuffer(getOperation(), options, outType);
// replace op with memref analogy
auto dummyMemrefOp = test::TestDummyMemrefOp::create(
rewriter, getLoc(), bufferizedOutType, *buffer);
@@ -1470,13 +1503,12 @@ TestMultiSlotAlloca::handleDestructuringComplete(
mlir::FailureOr<mlir::bufferization::BufferLikeType>
test::TestCreateTensorOp::getBufferType(
- mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+ mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
const mlir::bufferization::BufferizationState &,
llvm::SmallVector<::mlir::Value> &) {
- const auto type = dyn_cast<test::TestTensorType>(value.getType());
+ const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
if (type == nullptr)
return failure();
- return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
- getContext(), type.getShape(), type.getElementType(), nullptr));
+ return convertTensorToBuffer(getOperation(), options, type);
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6329d61..05a33cf 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
// Include the attribute definitions.
include "TestAttrDefs.td"
@@ -2335,7 +2336,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op",
}
//===----------------------------------------------------------------------===//
-// Copy Operation Test
+// Copy Operation Test
//===----------------------------------------------------------------------===//
def CopyOp : TEST_Op<"copy", []> {
@@ -3676,10 +3677,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
["bufferize", "bufferizesToMemoryRead",
"bufferizesToMemoryWrite", "getAliasingValues"]>]> {
let arguments = (ins
- Arg<TestTensorType>:$input
+ Arg<Bufferization_TensorLikeTypeInterface>:$input
);
let results = (outs
- Arg<TestTensorType>:$output
+ Arg<Bufferization_TensorLikeTypeInterface>:$output
);
let extraClassDefinition = [{
@@ -3701,10 +3702,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
let arguments = (ins
- Arg<TestMemrefType>:$input
+ Arg<Bufferization_BufferLikeTypeInterface>:$input
);
let results = (outs
- Arg<TestMemrefType>:$output
+ Arg<Bufferization_BufferLikeTypeInterface>:$output
);
}
@@ -3714,7 +3715,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
"bufferizesToMemoryWrite", "getAliasingValues",
"bufferizesToAllocation"]>]> {
let arguments = (ins);
- let results = (outs Arg<TestTensorType>:$output);
+ let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
let extraClassDefinition = [{
bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
const ::mlir::bufferization::AnalysisState&) {
@@ -3738,7 +3739,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
let arguments = (ins);
- let results = (outs Arg<TestMemrefType>:$output);
+ let results = (outs Arg<Bufferization_BufferLikeTypeInterface>:$output);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 97fc699..496f18b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -938,10 +938,10 @@ public:
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"
diff --git a/mlir/test/lib/Pass/TestRemarksPass.cpp b/mlir/test/lib/Pass/TestRemarksPass.cpp
index 3b25686..5ca2d1a 100644
--- a/mlir/test/lib/Pass/TestRemarksPass.cpp
+++ b/mlir/test/lib/Pass/TestRemarksPass.cpp
@@ -43,7 +43,12 @@ public:
<< remark::add("This is a test missed remark")
<< remark::reason("because we are testing the remark pipeline")
<< remark::suggest("try using the remark pipeline feature");
-
+ mlir::remark::passed(
+ loc,
+ remark::RemarkOpts::name("test-remark").category("category-1-passed"))
+ << remark::add("This is a test passed remark (should be dropped)")
+ << remark::reason("because we are testing the remark pipeline")
+ << remark::suggest("try using the remark pipeline feature");
mlir::remark::passed(
loc,
remark::RemarkOpts::name("test-remark").category("category-1-passed"))
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index 4e869e5..4be30d8 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -28,7 +28,7 @@
// CHECK: operation "test.op3"
// CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
-// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
+// CHECK{LITERAL}: [[maybe_unused]] static void populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...);
diff --git a/mlir/test/mlir-tblgen/cpp-class-comments.td b/mlir/test/mlir-tblgen/cpp-class-comments.td
index a896888..9dcf975 100644
--- a/mlir/test/mlir-tblgen/cpp-class-comments.td
+++ b/mlir/test/mlir-tblgen/cpp-class-comments.td
@@ -96,17 +96,14 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> {
}];
let methods = [
];
-// ATTR-INTERFACE: namespace mlir
-// ATTR-INTERFACE-NEXT: namespace a
-// ATTR-INTERFACE-NEXT: namespace traits
+// ATTR-INTERFACE: namespace mlir::a::traits {
// ATTR-INTERFACE-NEXT: /// Common trait for all layouts.
// ATTR-INTERFACE-NEXT: class EncodingTrait;
}
def SimpleEncodingTrait : AttrInterface<"SimpleEncodingTrait"> {
let cppNamespace = "a::traits";
-// ATTR-INTERFACE: namespace a {
-// ATTR-INTERFACE-NEXT: namespace traits {
+// ATTR-INTERFACE: namespace a::traits {
// ATTR-INTERFACE-NEXT: class SimpleEncodingTrait;
}
@@ -116,8 +113,7 @@ def SimpleOpInterface : OpInterface<"SimpleOpInterface"> {
Simple Op Interface description
}];
-// OP-INTERFACE: namespace a {
-// OP-INTERFACE-NEXT: namespace traits {
+// OP-INTERFACE: namespace a::traits {
// OP-INTERFACE-NEXT: /// Simple Op Interface description
// OP-INTERFACE-NEXT: class SimpleOpInterface;
}
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 26ee9f3..66c4018 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
+import mlir.ir as ir
import mlir.dialects.gpu as gpu
import mlir.dialects.gpu.passes
from mlir.passmanager import *
@@ -64,3 +65,95 @@ def testObjectAttr():
# CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
print(o)
assert o.kernels == kernelTable
+
+
+# CHECK-LABEL: testGPUFuncOp
+@run
+def testGPUFuncOp():
+ assert gpu.GPUFuncOp.__doc__ is not None
+ module = Module.create()
+ with InsertionPoint(module.body):
+ gpu_module_name = StringAttr.get("gpu_module")
+ gpumodule = gpu.GPUModuleOp(gpu_module_name)
+ block = gpumodule.bodyRegion.blocks.append()
+
+ def builder(func: gpu.GPUFuncOp) -> None:
+ gpu.GlobalIdOp(gpu.Dimension.x)
+ gpu.ReturnOp([])
+
+ with InsertionPoint(block):
+ name = StringAttr.get("kernel0")
+ func_type = ir.FunctionType.get(inputs=[], results=[])
+ type_attr = TypeAttr.get(func_type)
+ func = gpu.GPUFuncOp(type_attr, name)
+ func.attributes["sym_name"] = name
+ func.attributes["gpu.kernel"] = UnitAttr.get()
+
+ try:
+ func.entry_block
+ assert False, "Expected RuntimeError"
+ except RuntimeError as e:
+ assert (
+ str(e)
+ == "Entry block does not exist for kernel0. Do you need to call the add_entry_block() method on this GPUFuncOp?"
+ )
+
+ block = func.add_entry_block()
+ with InsertionPoint(block):
+ builder(func)
+
+ try:
+ func.add_entry_block()
+ assert False, "Expected RuntimeError"
+ except RuntimeError as e:
+ assert str(e) == "Entry block already exists for kernel0"
+
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="kernel1",
+ kernel=True,
+ body_builder=builder,
+ known_block_size=[1, 2, 3],
+ known_grid_size=DenseI32ArrayAttr.get([4, 5, 6]),
+ )
+
+ assert func.name.value == "kernel1"
+ assert func.function_type.value == func_type
+ assert func.arg_attrs == None
+ assert func.res_attrs == None
+ assert func.arguments == []
+ assert func.entry_block == func.body.blocks[0]
+ assert func.is_kernel
+ assert func.known_block_size == DenseI32ArrayAttr.get(
+ [1, 2, 3]
+ ), func.known_block_size
+ assert func.known_grid_size == DenseI32ArrayAttr.get(
+ [4, 5, 6]
+ ), func.known_grid_size
+
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="non_kernel_func",
+ body_builder=builder,
+ )
+ assert not func.is_kernel
+ assert func.known_block_size is None
+ assert func.known_grid_size is None
+
+ print(module)
+
+ # CHECK: gpu.module @gpu_module
+ # CHECK: gpu.func @kernel0() kernel {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @kernel1() kernel attributes
+ # CHECK-SAME: known_block_size = array<i32: 1, 2, 3>
+ # CHECK-SAME: known_grid_size = array<i32: 4, 5, 6>
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @non_kernel_func() {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py
new file mode 100644
index 0000000..8f2142a
--- /dev/null
+++ b/mlir/test/python/dialects/openacc.py
@@ -0,0 +1,171 @@
+# RUN: %PYTHON %s | FileCheck %s
+from unittest import result
+from mlir.ir import (
+ Context,
+ FunctionType,
+ Location,
+ Module,
+ InsertionPoint,
+ IntegerType,
+ IndexType,
+ MemRefType,
+ F32Type,
+ Block,
+ ArrayAttr,
+ Attribute,
+ UnitAttr,
+ StringAttr,
+ DenseI32ArrayAttr,
+ ShapedType,
+)
+from mlir.dialects import openacc, func, arith, memref
+from mlir.extras import types
+
+
+def run(f):
+ print("\n// TEST:", f.__name__)
+ with Context(), Location.unknown():
+ f()
+ return f
+
+
+@run
+def testParallelMemcpy():
+ module = Module.create()
+
+ dynamic = ShapedType.get_dynamic_size()
+ memref_f32_1d_any = MemRefType.get([dynamic], types.f32())
+
+ with InsertionPoint(module.body):
+ function_type = FunctionType.get(
+ [memref_f32_1d_any, memref_f32_1d_any, types.i64()], []
+ )
+ f = func.FuncOp(
+ type=function_type,
+ name="memcpy_idiom",
+ )
+ f.attributes["sym_visibility"] = StringAttr.get("public")
+
+ with InsertionPoint(f.add_entry_block()):
+ c1024 = arith.ConstantOp(types.i32(), 1024)
+ c128 = arith.ConstantOp(types.i32(), 128)
+
+ arg0, arg1, arg2 = f.arguments
+
+ copied = openacc.copyin(
+ acc_var=arg0.type,
+ var=arg0,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ created = openacc.create_(
+ acc_var=arg1.type,
+ var=arg1,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+
+ parallel_op = openacc.ParallelOp(
+ asyncOperands=[],
+ waitOperands=[],
+ numGangs=[c1024],
+ numWorkers=[],
+ vectorLength=[c128],
+ reductionOperands=[],
+ privateOperands=[],
+ firstprivateOperands=[],
+ dataClauseOperands=[],
+ )
+
+ # Set required device_type and segment attributes to satisfy verifier
+ acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")])
+ parallel_op.numGangsDeviceType = acc_device_none
+ parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1])
+ parallel_op.vectorLengthDeviceType = acc_device_none
+
+ parallel_block = Block.create_at_start(parent=parallel_op.region, arg_types=[])
+
+ with InsertionPoint(parallel_block):
+ c0 = arith.ConstantOp(types.i64(), 0)
+ c1 = arith.ConstantOp(types.i64(), 1)
+
+ loop_op = openacc.LoopOp(
+ results_=[],
+ lowerbound=[c0],
+ upperbound=[f.arguments[2]],
+ step=[c1],
+ gangOperands=[],
+ workerNumOperands=[],
+ vectorOperands=[],
+ tileOperands=[],
+ cacheOperands=[],
+ privateOperands=[],
+ reductionOperands=[],
+ firstprivateOperands=[],
+ )
+
+ # Set loop attributes: gang and independent on device_type<none>
+ acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")])
+ loop_op.gang = acc_device_none
+ loop_op.independent = acc_device_none
+
+ loop_block = Block.create_at_start(
+ parent=loop_op.region, arg_types=[types.i64()]
+ )
+
+ with InsertionPoint(loop_block):
+ idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0])
+ val = memref.load(memref=copied, indices=[idx])
+ memref.store(value=val, memref=created, indices=[idx])
+ openacc.YieldOp([])
+
+ openacc.YieldOp([])
+
+ deleted = openacc.delete(
+ acc_var=copied,
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ copied = openacc.copyout(
+ acc_var=created,
+ var=arg1,
+ var_type=types.f32(),
+ bounds=[],
+ async_operands=[],
+ implicit=False,
+ structured=True,
+ )
+ func.ReturnOp([])
+
+ print(module)
+
+ # CHECK: TEST: testParallelMemcpy
+ # CHECK-LABEL: func.func public @memcpy_idiom(
+ # CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) {
+ # CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32
+ # CHECK: %[[CONSTANT_1:.*]] = arith.constant 128 : i32
+ # CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32>
+ # CHECK: %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32>
+ # CHECK: acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) {
+ # CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i64
+ # CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i64
+ # CHECK: acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64) step (%[[CONSTANT_3]] : i64) {
+ # CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index
+ # CHECK: %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
+ # CHECK: memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
+ # CHECK: acc.yield
+ # CHECK: } attributes {independent = [#acc.device_type<none>]}
+ # CHECK: acc.yield
+ # CHECK: }
+ # CHECK: acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>)
+ # CHECK: acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>)
+ # CHECK: return
+ # CHECK: }
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index cb4cfc8c..1d4ede1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -569,12 +569,30 @@ def testOperationAttributes():
# CHECK: Attribute value b'text'
print(f"Attribute value {sattr.value_bytes}")
+ # Python dict-style iteration
# We don't know in which order the attributes are stored.
- # CHECK-DAG: NamedAttribute(dependent="text")
- # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
- # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
- for attr in op.attributes:
- print(str(attr))
+ # CHECK-DAG: dependent
+ # CHECK-DAG: other.attribute
+ # CHECK-DAG: some.attribute
+ for name in op.attributes:
+ print(name)
+
+ # Basic dict-like introspection
+ # CHECK: True
+ print("some.attribute" in op.attributes)
+ # CHECK: False
+ print("missing" in op.attributes)
+ # CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute']
+ print("Keys:", sorted(op.attributes.keys()))
+ # CHECK: Values count 3
+ print("Values count", len(op.attributes.values()))
+ # CHECK: Items count 3
+ print("Items count", len(op.attributes.items()))
+
+ # Dict() conversion test
+ d = {k: v.value for k, v in dict(op.attributes).items()}
+ # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
+ print("Dict mapping", d)
# Check that exceptions are raised as expected.
try:
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 6432fae..8842180 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -151,6 +151,7 @@ void registerTestSliceAnalysisPass();
void registerTestSPIRVCPURunnerPipeline();
void registerTestSPIRVFuncSignatureConversion();
void registerTestSPIRVVectorUnrolling();
+void registerTestStridedMetadataRangeAnalysisPass();
void registerTestTensorCopyInsertionPass();
void registerTestTensorLikeAndBufferLikePass();
void registerTestTensorTransforms();
@@ -299,6 +300,7 @@ void registerTestPasses() {
mlir::test::registerTestSPIRVCPURunnerPipeline();
mlir::test::registerTestSPIRVFuncSignatureConversion();
mlir::test::registerTestSPIRVVectorUnrolling();
+ mlir::test::registerTestStridedMetadataRangeAnalysisPass();
mlir::test::registerTestTensorCopyInsertionPass();
mlir::test::registerTestTensorLikeAndBufferLikePass();
mlir::test::registerTestTensorTransforms();
diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp
index f99dcdb..76122a0 100644
--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp
+++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include <set>
using namespace mlir;
@@ -41,6 +42,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
bool dumpODS, std::set<std::string> *includedFiles) {
llvm::SourceMgr sourceMgr;
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc());
// If we are dumping ODS information, also enable documentation to ensure the
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index d55ad482..11bf9ce 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -20,6 +20,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -701,11 +702,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
auto enumerants = enumInfo.getAllCases();
- SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(cppNamespace, namespaces, "::");
-
- for (auto ns : namespaces)
- os << "namespace " << ns << " {\n";
+ llvm::NamespaceEmitter ns(os, cppNamespace);
// Emit the enum class definition
emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
@@ -766,8 +763,7 @@ public:
os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
}
- for (auto ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
+ ns.close();
// Generate a generic parser and printer for the enum.
std::string qualName =
@@ -790,13 +786,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumInfo enumInfo(enumDef);
- StringRef cppNamespace = enumInfo.getCppNamespace();
- SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(cppNamespace, namespaces, "::");
-
- for (auto ns : namespaces)
- os << "namespace " << ns << " {\n";
+ llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace());
if (enumInfo.isBitEnum()) {
emitSymToStrFnForBitEnum(enumDef, os);
@@ -810,10 +801,6 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
if (enumInfo.genSpecializedAttr())
emitSpecializedAttrDef(enumDef, os);
-
- for (auto ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
- os << "\n";
}
static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 96af14d..11a2db4 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -416,7 +416,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
// Emit the function converting the enum attribute to its LLVM counterpart.
os << formatv(
- "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n",
+ "[[maybe_unused]] static {0} convert{1}ToLLVM({2}::{1} value) {{\n",
llvmClass, cppClassName, cppNamespace);
os << " switch (value) {\n";
@@ -444,7 +444,7 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute to its LLVM counterpart.
- os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
+ os << formatv("[[maybe_unused]] static int64_t "
"convert{0}ToLLVM({1}::{0} value) {{\n",
cppClassName, cppNamespace);
os << " switch (value) {\n";
@@ -474,7 +474,7 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
StringRef cppNamespace = enumInfo.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
- os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
+ os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM({2} "
"value) {{\n",
cppNamespace, cppClassName, llvmClass);
os << " switch (value) {\n";
@@ -509,10 +509,9 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
StringRef cppNamespace = enumInfo.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
- os << formatv(
- "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t "
- "value) {{\n",
- cppNamespace, cppClassName);
+ os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM(int64_t "
+ "value) {{\n",
+ cppNamespace, cppClassName);
os << " switch (value) {\n";
for (const auto &enumerant : enumInfo.getAllCases()) {
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index daae3c7..3718648 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4896,7 +4896,7 @@ static void emitOpClassDefs(const RecordKeeper &records,
constraintPrefix);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
staticVerifierEmitter.collectOpConstraints(defs);
- staticVerifierEmitter.emitOpConstraints(defs);
+ staticVerifierEmitter.emitOpConstraints();
// Emit the classes.
emitOpClasses(records, defs, os, staticVerifierEmitter,
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 730b5b2..ab8d534 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -342,11 +343,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
}
void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
- llvm::SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
- for (StringRef ns : namespaces)
- os << "namespace " << ns << " {\n";
-
+ llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
for (auto &method : interface.getMethods()) {
os << "template<typename " << valueTemplate << ">\n";
emitCPPType(method.getReturnType(), os);
@@ -442,18 +439,11 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
method.isStatic() ? &ctx : &nonStaticMethodFmt);
os << "\n}\n";
}
-
- for (StringRef ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
}
void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
- llvm::SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
- for (StringRef ns : namespaces)
- os << "namespace " << ns << " {\n";
-
- os << "namespace detail {\n";
+ auto cppNamespace = (interface.getCppNamespace() + "::detail").str();
+ llvm::NamespaceEmitter ns(os, cppNamespace);
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
@@ -504,10 +494,6 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
os << " };\n";
- os << "}// namespace detail\n";
-
- for (StringRef ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
}
static void emitInterfaceDeclMethods(const Interface &interface,
@@ -533,10 +519,7 @@ static void emitInterfaceDeclMethods(const Interface &interface,
}
void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
- llvm::SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
- for (StringRef ns : namespaces)
- os << "namespace " << ns << " {\n";
+ llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
// Emit a forward declaration of the interface class so that it becomes usable
// in the signature of its methods.
@@ -545,16 +528,10 @@ void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
StringRef interfaceName = interface.getName();
os << "class " << interfaceName << ";\n";
-
- for (StringRef ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
}
void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
- llvm::SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
- for (StringRef ns : namespaces)
- os << "namespace " << ns << " {\n";
+ llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
@@ -631,9 +608,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
}
os << "};\n";
-
- for (StringRef ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
}
bool InterfaceGenerator::emitInterfaceDecls() {
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 40bc1a9..c3034bb8 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -2120,7 +2120,7 @@ static void emitRewriters(const RecordKeeper &records, raw_ostream &os) {
}
// Emit function to add the generated matchers to the pattern list.
- os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
+ os << "[[maybe_unused]] void populateWithGenerated("
"::mlir::RewritePatternSet &patterns) {\n";
for (const auto &name : rewriterNames) {
os << " patterns.add<" << name << ">(patterns.getContext());\n";
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 3ead2f0..ca291b5 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -259,8 +259,8 @@ static void emitInterfaceDecl(const Availability &availability,
std::string interfaceTraitsName =
std::string(formatv("{0}Traits", interfaceName));
- StringRef cppNamespace = availability.getInterfaceClassNamespace();
- llvm::NamespaceEmitter nsEmitter(os, cppNamespace);
+ llvm::NamespaceEmitter nsEmitter(os,
+ availability.getInterfaceClassNamespace());
os << "class " << interfaceName << ";\n\n";
// Emit the traits struct containing the concept and model declarations.
@@ -418,15 +418,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
EnumInfo enumInfo(enumDef);
StringRef enumName = enumInfo.getEnumClassName();
- StringRef cppNamespace = enumInfo.getCppNamespace();
auto enumerants = enumInfo.getAllCases();
- llvm::SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(cppNamespace, namespaces, "::");
-
- for (auto ns : namespaces)
- os << "namespace " << ns << " {\n";
-
+ llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace());
llvm::StringSet<> handledClasses;
// Place all availability specifications to their corresponding
@@ -441,9 +435,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
enumName);
handledClasses.insert(className);
}
-
- for (auto ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
}
static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
@@ -459,31 +450,19 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumInfo enumInfo(enumDef);
- StringRef cppNamespace = enumInfo.getCppNamespace();
-
- llvm::SmallVector<StringRef, 2> namespaces;
- llvm::SplitString(cppNamespace, namespaces, "::");
-
- for (auto ns : namespaces)
- os << "namespace " << ns << " {\n";
+ llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace());
- if (enumInfo.isBitEnum()) {
+ if (enumInfo.isBitEnum())
emitAvailabilityQueryForBitEnum(enumDef, os);
- } else {
+ else
emitAvailabilityQueryForIntEnum(enumDef, os);
- }
-
- for (auto ns : llvm::reverse(namespaces))
- os << "} // namespace " << ns << "\n";
- os << "\n";
}
static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os,
records);
- auto defs = records.getAllDerivedDefinitions("EnumInfo");
- for (const auto *def : defs)
+ for (const Record *def : records.getAllDerivedDefinitions("EnumInfo"))
emitEnumDef(*def, os);
return false;
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index fe26fc1..2a58305 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -113,8 +113,7 @@ static Match tensorMatch(TensorId tid) { return Match(tid); }
static Match synZeroMatch() { return Match(); }
#define IMPL_BINOP_PATTERN(OP, KIND) \
- LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \
- const Match &e1) { \
+ [[maybe_unused]] static Match OP##Match(const Match &e0, const Match &e1) { \
return Match(KIND, e0, e1); \
}
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp
index bcbda90..09c576c 100644
--- a/mlir/unittests/IR/RemarkTest.cpp
+++ b/mlir/unittests/IR/RemarkTest.cpp
@@ -53,10 +53,12 @@ TEST(Remark, TestOutputOptimizationRemark) {
/*missed=*/categoryUnroll,
/*analysis=*/categoryRegister,
/*failed=*/categoryInliner};
-
+ std::unique_ptr<remark::RemarkEmittingPolicyAll> policy =
+ std::make_unique<remark::RemarkEmittingPolicyAll>();
LogicalResult isEnabled =
mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- context, yamlFile, llvm::remarks::Format::YAML, cats);
+ context, yamlFile, llvm::remarks::Format::YAML, std::move(policy),
+ cats);
ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
// PASS: something succeeded
@@ -202,9 +204,10 @@ TEST(Remark, TestOutputOptimizationRemarkDiagnostic) {
/*missed=*/categoryUnroll,
/*analysis=*/categoryRegister,
/*failed=*/categoryUnroll};
-
- LogicalResult isEnabled =
- remark::enableOptimizationRemarks(context, nullptr, cats, true);
+ std::unique_ptr<remark::RemarkEmittingPolicyAll> policy =
+ std::make_unique<remark::RemarkEmittingPolicyAll>();
+ LogicalResult isEnabled = remark::enableOptimizationRemarks(
+ context, nullptr, std::move(policy), cats, true);
ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
@@ -282,8 +285,11 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) {
/*analysis=*/std::nullopt,
/*failed=*/categoryLoopunroll};
+ std::unique_ptr<remark::RemarkEmittingPolicyAll> policy =
+ std::make_unique<remark::RemarkEmittingPolicyAll>();
LogicalResult isEnabled = remark::enableOptimizationRemarks(
- context, std::make_unique<MyCustomStreamer>(), cats, true);
+ context, std::make_unique<MyCustomStreamer>(), std::move(policy), cats,
+ true);
ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
// Remark 1: pass, category LoopUnroll
@@ -311,4 +317,66 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) {
EXPECT_NE(errOut.find(pass2Msg), std::string::npos); // printed
EXPECT_EQ(errOut.find(pass3Msg), std::string::npos); // filtered out
}
+
+TEST(Remark, TestRemarkFinal) {
+ testing::internal::CaptureStderr();
+ const auto *pass1Msg = "I failed";
+ const auto *pass2Msg = "I failed too";
+ const auto *pass3Msg = "I succeeded";
+ const auto *pass4Msg = "I succeeded too";
+
+ std::string categoryLoopunroll("LoopUnroll");
+
+ std::string seenMsg = "";
+
+ {
+ MLIRContext context;
+ Location loc = FileLineColLoc::get(&context, "test.cpp", 1, 5);
+ Location locOther = FileLineColLoc::get(&context, "test.cpp", 55, 5);
+
+ // Setup the remark engine
+ mlir::remark::RemarkCategories cats{/*all=*/"",
+ /*passed=*/categoryLoopunroll,
+ /*missed=*/categoryLoopunroll,
+ /*analysis=*/categoryLoopunroll,
+ /*failed=*/categoryLoopunroll};
+
+ std::unique_ptr<remark::RemarkEmittingPolicyFinal> policy =
+ std::make_unique<remark::RemarkEmittingPolicyFinal>();
+ LogicalResult isEnabled = remark::enableOptimizationRemarks(
+ context, std::make_unique<MyCustomStreamer>(), std::move(policy), cats,
+ true);
+ ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine";
+
+ // Remark 1: failure
+ remark::failed(
+ loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll))
+ << pass1Msg;
+
+ // Remark 2: missed
+ remark::missed(
+ loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll))
+ << remark::reason(pass2Msg);
+
+ // Remark 3: pass
+ remark::passed(
+ loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll))
+ << pass3Msg;
+
+ // Remark 4: pass
+ remark::passed(
+ locOther,
+ remark::RemarkOpts::name("Unroller").category(categoryLoopunroll))
+ << pass4Msg;
+ }
+
+ llvm::errs().flush();
+ std::string errOut = ::testing::internal::GetCapturedStderr();
+
+ // Containment checks for messages.
+ EXPECT_EQ(errOut.find(pass1Msg), std::string::npos); // dropped
+ EXPECT_EQ(errOut.find(pass2Msg), std::string::npos); // dropped
+ EXPECT_NE(errOut.find(pass3Msg), std::string::npos); // shown
+ EXPECT_NE(errOut.find(pass4Msg), std::string::npos); // shown
+}
} // namespace
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index f80a181..3712a6b 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -31,13 +31,16 @@ import argparse
import os # Used to advertise this file's name ("autogenerated_note").
import re
import sys
+from collections import Counter
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
-// The script is designed to make adding checks to
-// a test case fast, it is *not* designed to be authoritative
-// about what constitutes a good test! The CHECK should be
-// minimized and named to reflect the test intent.
+// This script is intended to make adding checks to a test case quick and easy.
+// It is *not* authoritative about what constitutes a good test. After using the
+// script, be sure to review and refine the generated checks. For example,
+// CHECK lines should be minimized and named to reflect the test’s intent.
+// For comprehensive guidelines, see:
+// * https://mlir.llvm.org/getting_started/TestingGuide/
"""
@@ -45,6 +48,9 @@ ADVERT_END = """
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)
+# Regex matching ` = dialect.op_name` and capturing `op_name` (e.g. `transfer_read`).
+SSA_OP_NAME_RE = re.compile(r"\b(?:\s=\s[a-z_]+)[.]([a-z_]+)\b")
+
# Regex matching the left-hand side of an assignment
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)
@@ -63,7 +69,12 @@ ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)
class VariableNamer:
def __init__(self, variable_names):
self.scopes = []
+ # Counter for generic FileCheck names, e.g. VAL_#N
self.name_counter = 0
+ # Counters for FileCheck names derived from Op names, e.g.
+ # TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a
+ # dedicated counter for every Op type present in the input.
+ self.op_name_counter = Counter()
# Number of variable names to still generate in parent scope
self.generate_in_parent_scope_left = 0
@@ -77,17 +88,29 @@ class VariableNamer:
self.generate_in_parent_scope_left = n
# Generate a substitution name for the given ssa value name.
- def generate_name(self, source_variable_name, use_ssa_name):
+ def generate_name(self, source_variable_name, use_ssa_name, op_name=""):
# Compute variable name
- variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
- if variable_name == '':
+ variable_name = (
+ self.variable_names.pop(0) if len(self.variable_names) > 0 else ""
+ )
+ if variable_name == "":
# If `use_ssa_name` is set, use the MLIR SSA value name to generate
# a FileCHeck substation string. As FileCheck requires these
# strings to start with a character, skip MLIR variables starting
# with a digit (e.g. `%0`).
+ #
+ # The next fallback option is to use the op name, if the
+ # corresponding match succeeds.
+ #
+ # If neither worked, use a generic name: `VAL_#N`.
if use_ssa_name and source_variable_name[0].isalpha():
variable_name = source_variable_name.upper()
+ elif op_name != "":
+ variable_name = (
+ op_name.upper() + "_" + str(self.op_name_counter[op_name])
+ )
+ self.op_name_counter[op_name] += 1
else:
variable_name = "VAL_" + str(self.name_counter)
self.name_counter += 1
@@ -123,6 +146,7 @@ class VariableNamer:
def clear_names(self):
self.name_counter = 0
self.used_variable_names = set()
+ self.op_name_counter.clear()
class AttributeNamer:
@@ -170,8 +194,12 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re
# Process the rest that contained an SSA value name.
for chunk in line_chunks:
- m = SSA_RE.match(chunk)
- ssa_name = m.group(0) if m is not None else ''
+ ssa = SSA_RE.match(chunk)
+ op_name_with_dialect = SSA_OP_NAME_RE.search(chunk)
+ ssa_name = ssa.group(0) if ssa is not None else ""
+ op_name = (
+ op_name_with_dialect.group(1) if op_name_with_dialect is not None else ""
+ )
# Check if an existing variable exists for this name.
variable = None
@@ -185,7 +213,7 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re
output_line += "%[[" + variable + "]]"
else:
# Otherwise, generate a new variable.
- variable = variable_namer.generate_name(ssa_name, use_ssa_name)
+ variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name)
if strict_name_re:
# Use stricter regexp for the variable name, if requested.
# Greedy matching may cause issues with the generic '.*'